@@ -1,3 +1,4 @@
+import warnings
 from collections.abc import Sequence
 from copy import copy
 from textwrap import dedent
@@ -19,9 +20,9 @@
 from pytensor.misc.frozendict import frozendict
 from pytensor.printing import Printer, pprint
 from pytensor.scalar import get_scalar_type
+from pytensor.scalar.basic import Composite, transfer_type, upcast
 from pytensor.scalar.basic import bool as scalar_bool
 from pytensor.scalar.basic import identity as scalar_identity
-from pytensor.scalar.basic import transfer_type, upcast
 from pytensor.tensor import elemwise_cgen as cgen
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
@@ -364,6 +365,7 @@ def __init__(
         self.name = name
         self.scalar_op = scalar_op
         self.inplace_pattern = inplace_pattern
+        self.ufunc = None
         self.destroy_map = {o: [i] for o, i in self.inplace_pattern.items()}

         if nfunc_spec is None:
@@ -375,14 +377,12 @@ def __init__(
     def __getstate__(self):
         d = copy(self.__dict__)
         d.pop("ufunc")
-        d.pop("nfunc")
-        d.pop("__epydoc_asRoutine", None)
         return d

     def __setstate__(self, d):
+        d.pop("nfunc", None)  # This used to be stored in the Op, not anymore
         super().__setstate__(d)
         self.ufunc = None
-        self.nfunc = None
         self.inplace_pattern = frozendict(self.inplace_pattern)

     def get_output_info(self, *inputs):
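
Note on the `__getstate__`/`__setstate__` change above: `nfunc` is no longer stored on the Op, so only stale pickles can still carry it, and `__setstate__` now drops it on load while the ufunc cache is reset. A minimal sketch of the intended round-trip, assuming a standard PyTensor install where `pytensor.scalar.add` and `Elemwise` import as in this diff:

```python
import pickle

import pytensor.scalar as ps
from pytensor.tensor.elemwise import Elemwise

op = Elemwise(ps.add)
restored = pickle.loads(pickle.dumps(op))

# The per-Op ufunc cache is dropped by __getstate__ and reset by __setstate__,
# so it is rebuilt lazily the first time the Op is evaluated in Python mode.
assert restored.ufunc is None
# State dicts written before this change may still contain "nfunc";
# __setstate__ now pops it instead of leaving a stale attribute behind.
```
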
@@ -623,31 +623,49 @@ def transform(r):

         return ret

-    def prepare_node(self, node, storage_map, compute_map, impl):
-        # Postpone the ufunc building to the last minutes due to:
-        # - NumPy ufunc support only up to 32 operands (inputs and outputs)
-        # But our c code support more.
-        # - nfunc is reused for scipy and scipy is optional
-        if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py":
-            impl = "c"
-
-        if getattr(self, "nfunc_spec", None) and impl != "c":
-            self.nfunc = import_func_from_string(self.nfunc_spec[0])
-
+    def _create_node_ufunc(self, node) -> None:
         if (
-            (len(node.inputs) + len(node.outputs)) <= 32
-            and (self.nfunc is None or self.scalar_op.nin != len(node.inputs))
-            and self.ufunc is None
-            and impl == "py"
+            self.nfunc_spec is not None
+            # Some scalar Ops like `Add` allow for a variable number of inputs,
+            # whereas the numpy counterpart does not.
+            and len(node.inputs) == self.nfunc_spec[1]
         ):
+            # Do we really need to cache this import in the Op?
+            # If it's so costly, just memorize `import_func_from_string`
+            ufunc = import_func_from_string(self.nfunc_spec[0])
+            if ufunc is None:
+                raise ValueError(
+                    f"Could not import ufunc {self.nfunc_spec[0]} for {self}"
+                )
+
+        elif self.ufunc is not None:
+            # Cached before
+            ufunc = self.ufunc
+
+        else:
+            if (len(node.inputs) + len(node.outputs)) > 32:
+                if isinstance(self.scalar_op, Composite):
+                    warnings.warn(
+                        "Trying to create a Python Composite Elemwise function with more than 32 operands.\n"
+                        "This operation should not have been introduced if the C-backend is not properly setup. "
+                        'Make sure it is, or disable it by setting pytensor.config.cxx = "" (empty string).\n'
+                        "Alternatively, consider using an optional backend like NUMBA or JAX, by setting "
+                        '`pytensor.config.mode = "NUMBA" (or "JAX").'
+                    )
+                else:
+                    warnings.warn(
+                        f"Trying to create a Python Elemwise function for the scalar Op {self.scalar_op} "
+                        f"with more than 32 operands. This will likely fail."
+                    )
+
             ufunc = np.frompyfunc(
                 self.scalar_op.impl, len(node.inputs), self.scalar_op.nout
             )
-            if self.scalar_op.nin > 0:
-                # We can reuse it for many nodes
+            if self.scalar_op.nin > 0:  # Default in base class is -1
+                # Op has constant signature, so we can reuse ufunc for many nodes. Cache it.
                 self.ufunc = ufunc
-            else:
-                node.tag.ufunc = ufunc
+
+        node.tag.ufunc = ufunc

         # Numpy ufuncs will sometimes perform operations in
         # float16, in particular when the input is int8.
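
For context on the fallback branch above: `np.frompyfunc` wraps the scalar Op's Python `impl` in a broadcasting ufunc, and such ufuncs always return object arrays, which is why `perform` casts each output to the node's dtype (see the last hunk). A standalone sketch, with `scalar_add` as a hypothetical stand-in for `self.scalar_op.impl`:

```python
import numpy as np


def scalar_add(x, y):  # hypothetical stand-in for a scalar Op's impl
    return x + y


# nin=2, nout=1 mirror a fixed scalar signature; because the signature is
# constant, the resulting ufunc can be cached on the Op and reused across nodes.
ufunc = np.frompyfunc(scalar_add, 2, 1)

out = ufunc(np.arange(3), np.arange(3))
print(out.dtype)                         # object
print(np.asarray(out, dtype="float64"))  # [0. 2. 4.]
```
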
@@ -669,6 +687,11 @@ def prepare_node(self, node, storage_map, compute_map, impl):
             char = np.sctype2char(out_dtype)
             sig = char * node.nin + "->" + char * node.nout
             node.tag.sig = sig
+
+    def prepare_node(self, node, storage_map, compute_map, impl):
+        if impl == "py":
+            self._create_node_ufunc(node)
+
         node.tag.fake_node = Apply(
             self.scalar_op,
             [
@@ -684,71 +707,36 @@ def prepare_node(self, node, storage_map, compute_map, impl):
         self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl)

     def perform(self, node, inputs, output_storage):
-        if (len(node.inputs) + len(node.outputs)) > 32:
-            # Some versions of NumPy will segfault, other will raise a
-            # ValueError, if the number of operands in an ufunc is more than 32.
-            # In that case, the C version should be used, or Elemwise fusion
-            # should be disabled.
-            # FIXME: This no longer calls the C implementation!
-            super().perform(node, inputs, output_storage)
+        ufunc = getattr(node.tag, "ufunc", None)
+        if ufunc is None:
+            self._create_node_ufunc(node)
+            ufunc = node.tag.ufunc

         self._check_runtime_broadcast(node, inputs)

-        ufunc_args = inputs
         ufunc_kwargs = {}
-        # We supported in the past calling manually op.perform.
-        # To keep that support we need to sometimes call self.prepare_node
-        if self.nfunc is None and self.ufunc is None:
-            self.prepare_node(node, None, None, "py")
-        if self.nfunc and len(inputs) == self.nfunc_spec[1]:
-            ufunc = self.nfunc
-            nout = self.nfunc_spec[2]
-            if hasattr(node.tag, "sig"):
-                ufunc_kwargs["sig"] = node.tag.sig
-            # Unfortunately, the else case does not allow us to
-            # directly feed the destination arguments to the nfunc
-            # since it sometimes requires resizing. Doing this
-            # optimization is probably not worth the effort, since we
-            # should normally run the C version of the Op.
-        else:
-            # the second calling form is used because in certain versions of
-            # numpy the first (faster) version leads to segfaults
-            if self.ufunc:
-                ufunc = self.ufunc
-            elif not hasattr(node.tag, "ufunc"):
-                # It happen that make_thunk isn't called, like in
-                # get_underlying_scalar_constant_value
-                self.prepare_node(node, None, None, "py")
-                # prepare_node will add ufunc to self or the tag
-                # depending if we can reuse it or not. So we need to
-                # test both again.
-                if self.ufunc:
-                    ufunc = self.ufunc
-                else:
-                    ufunc = node.tag.ufunc
-            else:
-                ufunc = node.tag.ufunc
-
-            nout = ufunc.nout
+        if hasattr(node.tag, "sig"):
+            ufunc_kwargs["sig"] = node.tag.sig

-        variables = ufunc(*ufunc_args, **ufunc_kwargs)
+        outputs = ufunc(*inputs, **ufunc_kwargs)

-        if nout == 1:
-            variables = [variables]
+        if not isinstance(outputs, tuple):
+            outputs = (outputs,)

-        for i, (variable, storage, nout) in enumerate(
-            zip(variables, output_storage, node.outputs)
+        for i, (out, out_storage, node_out) in enumerate(
+            zip(outputs, output_storage, node.outputs)
         ):
-            storage[0] = variable = np.asarray(variable, dtype=nout.dtype)
+            # Numpy frompyfunc always returns object arrays
+            out_storage[0] = out = np.asarray(out, dtype=node_out.dtype)

             if i in self.inplace_pattern:
-                odat = inputs[self.inplace_pattern[i]]
-                odat[...] = variable
-                storage[0] = odat
+                inp = inputs[self.inplace_pattern[i]]
+                inp[...] = out
+                out_storage[0] = inp

             # numpy.real return a view!
-            if not variable.flags.owndata:
-                storage[0] = variable.copy()
+            if not out.flags.owndata:
+                out_storage[0] = out.copy()

     @staticmethod
     def _check_runtime_broadcast(node, inputs):
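
On the `owndata` check at the end of `perform`: some NumPy routines (the comment cites `numpy.real`) return views into their inputs, and storing a view in `output_storage` would alias the input buffer. A quick standalone illustration of why the copy is needed:

```python
import numpy as np

x = np.array([1 + 2j, 3 + 4j])
out = np.real(x)
print(out.flags.owndata)  # False: `out` is a view into `x`

# Same guard as in perform(): copy so the stored output owns its memory.
if not out.flags.owndata:
    out = out.copy()
print(out.flags.owndata)  # True
```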