@@ -522,10 +522,43 @@ def c_support_code(self, **kwargs):
522522 # In that case we add the 'int' type to the real types.
523523 real_types .append ("int" )
524524
525+ # Macros for backwards compatibility with numpy < 2.0
526+ #
527+ # In numpy 2.0+, these are defined in npy_math.h, but
528+ # for early versions, they must be vendored by users (e.g. PyTensor)
529+ backwards_compat_macros = """
530+ #ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
531+ #define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
532+
533+ #include <numpy/npy_math.h>
534+
535+ #ifndef NPY_CSETREALF
536+ #define NPY_CSETREALF(c, r) (c)->real = (r)
537+ #endif
538+ #ifndef NPY_CSETIMAGF
539+ #define NPY_CSETIMAGF(c, i) (c)->imag = (i)
540+ #endif
541+ #ifndef NPY_CSETREAL
542+ #define NPY_CSETREAL(c, r) (c)->real = (r)
543+ #endif
544+ #ifndef NPY_CSETIMAG
545+ #define NPY_CSETIMAG(c, i) (c)->imag = (i)
546+ #endif
547+ #ifndef NPY_CSETREALL
548+ #define NPY_CSETREALL(c, r) (c)->real = (r)
549+ #endif
550+ #ifndef NPY_CSETIMAGL
551+ #define NPY_CSETIMAGL(c, i) (c)->imag = (i)
552+ #endif
553+
554+ #endif
555+ """
556+
525557 def _make_get_set_real_imag (scalar_type : str ) -> str :
526558 """Make overloaded getter/setter functions for real/imag parts of numpy complex types.
527559
528- The functions called by these getter/setter functions are defining in npy_math.h
560+ The functions called by these getter/setter functions are defining in npy_math.h, or
561+ in the `backward_compat_macros` defined above.
529562
530563 Args:
531564 scalar_type: float, double, or longdouble
@@ -536,11 +569,11 @@ def _make_get_set_real_imag(scalar_type: str) -> str:
536569 """
537570 complex_type = "npy_c" + scalar_type
538571 suffix = "" if scalar_type == "double" else scalar_type [0 ]
539- return_type = scalar_type
540572
541573 if scalar_type == "longdouble" :
542- scalar_type += "_t"
543- return_type = "npy_" + return_type
574+ scalar_type = "npy_" + scalar_type
575+
576+ return_type = scalar_type
544577
545578 template = f"""
546579 static inline { return_type } get_real(const { complex_type } z)
@@ -550,7 +583,7 @@ def _make_get_set_real_imag(scalar_type: str) -> str:
550583
551584 static inline void set_real({ complex_type } *z, const { scalar_type } r)
552585 {{
553- npy_csetreal { suffix } (z, r);
586+ NPY_CSETREAL { suffix . upper () } (z, r);
554587 }}
555588
556589 static inline { return_type } get_imag(const { complex_type } z)
@@ -560,17 +593,28 @@ def _make_get_set_real_imag(scalar_type: str) -> str:
560593
561594 static inline void set_imag({ complex_type } *z, const { scalar_type } i)
562595 {{
563- npy_csetimag { suffix } (z, i);
596+ NPY_CSETIMAG { suffix . upper () } (z, i);
564597 }}
565598 """
566599 return template
567600
568- # TODO: add guard code to prevent this from being defining twice, in case we need to add it somewhere else
569601 get_set_aliases = "\n " .join (
570602 _make_get_set_real_imag (stype )
571603 for stype in ["float" , "double" , "longdouble" ]
572604 )
573605
606+ get_set_aliases = backwards_compat_macros + "\n " + get_set_aliases
607+
608+ # Template for defining pytensor_complex64 and pytensor_complex128 structs/classes
609+ #
610+ # The npy_complex64, npy_complex128 types are aliases defined at run time based on
611+ # the size of floats and doubles on the machine. This means that both types are
612+ # not necessarily defined on every machine, but a machine with 32-bit floats and
613+ # 64-bit doubles will have npy_complex64 as an alias of npy_cfloat and npy_complex128
614+ # as an alias of npy_complex128.
615+ #
616+ # In any case, the get/set real/imag functions defined above will always work for
617+ # npy_complex64 and npy_complex128.
574618 template = """
575619 struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s {
576620 typedef pytensor_complex%(nbits)s complex_type;
@@ -719,7 +763,7 @@ def c_init_code(self, **kwargs):
719763 return ["import_array();" ]
720764
721765 def c_code_cache_version (self ):
722- return (15 , np .version .git_revision )
766+ return (18 , np .version .git_revision )
723767
724768 def get_shape_info (self , obj ):
725769 return obj .itemsize
0 commit comments