Skip to content

Commit 50b4c4b

Browse files
committed
Make blockwise perform method node dependent
1 parent a377c22 commit 50b4c4b

File tree

2 files changed

+40
-11
lines changed

2 files changed

+40
-11
lines changed

pytensor/tensor/blockwise.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from collections.abc import Sequence
2-
from copy import copy
32
from typing import Any, cast
43

54
import numpy as np
@@ -79,7 +78,6 @@ def __init__(
7978
self.name = name
8079
self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
8180
self.gufunc_spec = gufunc_spec
82-
self._gufunc = None
8381
if destroy_map is not None:
8482
self.destroy_map = destroy_map
8583
if self.destroy_map != core_op.destroy_map:
@@ -91,11 +89,6 @@ def __init__(
9189

9290
super().__init__(**kwargs)
9391

94-
def __getstate__(self):
95-
d = copy(self.__dict__)
96-
d["_gufunc"] = None
97-
return d
98-
9992
def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
10093
core_input_types = []
10194
for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
@@ -320,8 +313,7 @@ def core_func(*inner_inputs):
320313
else:
321314
return tuple(r[0] for r in inner_outputs)
322315

323-
self._gufunc = np.vectorize(core_func, signature=self.signature)
324-
return self._gufunc
316+
node.tag.gufunc = np.vectorize(core_func, signature=self.signature)
325317

326318
def _check_runtime_broadcast(self, node, inputs):
327319
batch_ndim = self.batch_ndim(node)
@@ -340,10 +332,12 @@ def _check_runtime_broadcast(self, node, inputs):
340332
)
341333

342334
def perform(self, node, inputs, output_storage):
343-
gufunc = self._gufunc
335+
gufunc = getattr(node.tag, "gufunc", None)
344336

345337
if gufunc is None:
346-
gufunc = self._create_gufunc(node)
338+
# Cache it once per node
339+
self._create_gufunc(node)
340+
gufunc = node.tag.gufunc
347341

348342
self._check_runtime_broadcast(node, inputs)
349343

tests/tensor/test_blockwise.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,41 @@
2828
from pytensor.tensor.utils import _parse_gufunc_signature
2929

3030

31+
def test_perform_method_per_node():
32+
"""Confirm that Blockwise uses one perform method per node.
33+
34+
This is important if the perform method requires node information (such as dtypes)
35+
"""
36+
37+
class NodeDependentPerformOp(Op):
38+
def make_node(self, x):
39+
return Apply(self, [x], [x.type()])
40+
41+
def perform(self, node, inputs, outputs):
42+
[x] = inputs
43+
if node.inputs[0].type.dtype.startswith("float"):
44+
y = x + 1
45+
else:
46+
y = x - 1
47+
outputs[0][0] = y
48+
49+
blockwise_op = Blockwise(core_op=NodeDependentPerformOp(), signature="()->()")
50+
x = tensor("x", shape=(3,), dtype="float32")
51+
y = tensor("y", shape=(3,), dtype="int32")
52+
53+
out_x = blockwise_op(x)
54+
out_y = blockwise_op(y)
55+
fn = pytensor.function([x, y], [out_x, out_y])
56+
[op1, op2] = [node.op for node in fn.maker.fgraph.apply_nodes]
57+
# Confirm both nodes have the same Op
58+
assert op1 is blockwise_op
59+
assert op1 is op2
60+
61+
res_out_x, res_out_y = fn(np.zeros(3, dtype="float32"), np.zeros(3, dtype="int32"))
62+
np.testing.assert_array_equal(res_out_x, np.ones(3, dtype="float32"))
63+
np.testing.assert_array_equal(res_out_y, -np.ones(3, dtype="int32"))
64+
65+
3166
def test_vectorize_blockwise():
3267
mat = tensor(shape=(None, None))
3368
tns = tensor(shape=(None, None, None))

0 commit comments

Comments
 (0)