11from collections .abc import Sequence
2- from copy import copy
32from typing import Any , cast
43
54import numpy as np
@@ -79,7 +78,6 @@ def __init__(
7978 self .name = name
8079 self .inputs_sig , self .outputs_sig = _parse_gufunc_signature (signature )
8180 self .gufunc_spec = gufunc_spec
82- self ._gufunc = None
8381 if destroy_map is not None :
8482 self .destroy_map = destroy_map
8583 if self .destroy_map != core_op .destroy_map :
@@ -91,11 +89,6 @@ def __init__(
9189
9290 super ().__init__ (** kwargs )
9391
94- def __getstate__ (self ):
95- d = copy (self .__dict__ )
96- d ["_gufunc" ] = None
97- return d
98-
9992 def _create_dummy_core_node (self , inputs : Sequence [TensorVariable ]) -> Apply :
10093 core_input_types = []
10194 for i , (inp , sig ) in enumerate (zip (inputs , self .inputs_sig )):
@@ -320,8 +313,7 @@ def core_func(*inner_inputs):
320313 else :
321314 return tuple (r [0 ] for r in inner_outputs )
322315
323- self ._gufunc = np .vectorize (core_func , signature = self .signature )
324- return self ._gufunc
316+ node .tag .gufunc = np .vectorize (core_func , signature = self .signature )
325317
326318 def _check_runtime_broadcast (self , node , inputs ):
327319 batch_ndim = self .batch_ndim (node )
@@ -340,10 +332,12 @@ def _check_runtime_broadcast(self, node, inputs):
340332 )
341333
342334 def perform (self , node , inputs , output_storage ):
343- gufunc = self . _gufunc
335+ gufunc = getattr ( node . tag , "gufunc" , None )
344336
345337 if gufunc is None :
346- gufunc = self ._create_gufunc (node )
338+ # Cache it once per node
339+ self ._create_gufunc (node )
340+ gufunc = node .tag .gufunc
347341
348342 self ._check_runtime_broadcast (node , inputs )
349343
0 commit comments