@@ -1,5 +1,4 @@
 from collections.abc import Sequence
-from copy import copy
 from typing import Any, cast
 
 import numpy as np
@@ -79,7 +78,6 @@ def __init__(
         self.name = name
         self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
         self.gufunc_spec = gufunc_spec
-        self._gufunc = None
         if destroy_map is not None:
             self.destroy_map = destroy_map
         if self.destroy_map != core_op.destroy_map:
@@ -91,11 +89,6 @@ def __init__(
 
         super().__init__(**kwargs)
 
-    def __getstate__(self):
-        d = copy(self.__dict__)
-        d["_gufunc"] = None
-        return d
-
     def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
         core_input_types = []
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
@@ -296,32 +289,40 @@ def L_op(self, inputs, outs, ograds):
 
         return rval
 
-    def _create_gufunc(self, node):
+    def _create_node_gufunc(self, node) -> None:
+        """Define (or retrieve) the node gufunc used in `perform`.
+
+        If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly.
+        Otherwise, we default to `np.vectorize` of the core_op `perform` method for a dummy node.
+
+        The gufunc is stored in the tag of the node.
+        """
         gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)
 
         if gufunc_spec is not None:
-            self._gufunc = import_func_from_string(gufunc_spec[0])
-            if self._gufunc:
-                return self._gufunc
-            else:
+            gufunc = import_func_from_string(gufunc_spec[0])
+            if gufunc is None:
                 raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
 
-        n_outs = len(self.outputs_sig)
-        core_node = self._create_dummy_core_node(node.inputs)
+        else:
+            # Wrap core_op perform method in numpy vectorize
+            n_outs = len(self.outputs_sig)
+            core_node = self._create_dummy_core_node(node.inputs)
 
-        def core_func(*inner_inputs):
-            inner_outputs = [[None] for _ in range(n_outs)]
+            def core_func(*inner_inputs):
+                inner_outputs = [[None] for _ in range(n_outs)]
 
-            inner_inputs = [np.asarray(inp) for inp in inner_inputs]
-            self.core_op.perform(core_node, inner_inputs, inner_outputs)
+                inner_inputs = [np.asarray(inp) for inp in inner_inputs]
+                self.core_op.perform(core_node, inner_inputs, inner_outputs)
 
-            if len(inner_outputs) == 1:
-                return inner_outputs[0][0]
-            else:
-                return tuple(r[0] for r in inner_outputs)
+                if len(inner_outputs) == 1:
+                    return inner_outputs[0][0]
+                else:
+                    return tuple(r[0] for r in inner_outputs)
+
+            gufunc = np.vectorize(core_func, signature=self.signature)
 
-        self._gufunc = np.vectorize(core_func, signature=self.signature)
-        return self._gufunc
+        node.tag.gufunc = gufunc
 
     def _check_runtime_broadcast(self, node, inputs):
         batch_ndim = self.batch_ndim(node)
@@ -340,10 +341,12 @@ def _check_runtime_broadcast(self, node, inputs):
                 )
 
     def perform(self, node, inputs, output_storage):
-        gufunc = self._gufunc
+        gufunc = getattr(node.tag, "gufunc", None)
 
         if gufunc is None:
-            gufunc = self._create_gufunc(node)
+            # Cache it once per node
+            self._create_node_gufunc(node)
+            gufunc = node.tag.gufunc
 
         self._check_runtime_broadcast(node, inputs)
 
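Note (not part of the patch): below is a minimal, standalone sketch of the `np.vectorize`-with-signature fallback that `_create_node_gufunc` relies on when no `gufunc_spec` is available. The `(m,n),(n)->(m)` signature and the `core_matvec` helper are hypothetical, chosen only to show how the vectorized callable broadcasts a core computation over leading batch dimensions.

```python
import numpy as np


def core_matvec(a, b):
    # Hypothetical core function: a single matrix-vector product.
    return a @ b


# np.vectorize with a gufunc signature loops the core function over any
# leading batch dimensions, which is the role the fallback gufunc plays
# when Blockwise.perform calls it.
gufunc = np.vectorize(core_matvec, signature="(m,n),(n)->(m)")

A = np.ones((5, 3, 4))  # batch of 5 matrices
x = np.ones((5, 4))     # batch of 5 vectors
out = gufunc(A, x)      # core op applied once per batch entry
assert out.shape == (5, 3)
```

Since the cached gufunc now lives on `node.tag` rather than in the Op's `__dict__`, the `__getstate__` override that reset `_gufunc` before pickling is no longer needed, which is why the patch removes it.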