Commit 8673fc4 (parent: c8d7763)

Numba Blockwise: Force scalar inner inputs to be arrays
5 files changed, 64 insertions(+), 10 deletions(-)

pytensor/link/numba/dispatch/blockwise.py (3 additions, 0 deletions)

@@ -86,6 +86,7 @@ def impl(*inputs_and_core_shapes):
             output_bc_patterns,
             output_dtypes,
             inplace_pattern,
+            False,  # allow_core_scalar
             (),  # constant_inputs
             inputs,
             tuple_core_shapes,
@@ -98,6 +99,7 @@ def impl(*inputs_and_core_shapes):
     # If the core op cannot be cached, the Blockwise wrapper cannot be cached either
         blockwise_key = None
     else:
+        blockwise_cache_version = 1
         blockwise_key = "_".join(
             map(
                 str,
@@ -108,6 +110,7 @@ def impl(*inputs_and_core_shapes):
                     blockwise_op.signature,
                     input_bc_patterns,
                     core_op_key,
+                    blockwise_cache_version,
                 ),
             )
         )
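
Two things change here: Blockwise now passes allow_core_scalar=False, so an input with zero core dimensions reaches the core op's perform as a 0d ndarray rather than an unboxed scalar, and a new blockwise_cache_version is folded into the cache key, so kernels compiled under the old calling convention are not reused. A minimal NumPy sketch of the changed contract (illustrative only, not pytensor code):

import numpy as np

# Sketch of the contract: what a "()->()" core input looks like on each path.
x0 = np.array(1.5)  # 0d ndarray: what a Blockwise core op's perform now sees
assert isinstance(x0, np.ndarray) and x0.ndim == 0

s = x0.item()       # plain scalar: what Elemwise scalar ops continue to see
assert isinstance(s, float) and not isinstance(s, np.ndarray)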

pytensor/link/numba/dispatch/elemwise.py (1 addition, 0 deletions)

@@ -365,6 +365,7 @@ def impl(*inputs):
             output_bc_patterns_enc,
             output_dtypes_enc,
             inplace_pattern_enc,
+            True,  # allow_core_scalar
             (),  # constant_inputs
             inputs,
             core_output_shapes,  # core_shapes

pytensor/link/numba/dispatch/random.py (1 addition, 0 deletions)

@@ -470,6 +470,7 @@ def impl(core_shape, rng, size, *dist_params):
             output_bc_patterns,
             output_dtypes,
             inplace_pattern,
+            True,  # allow_core_scalar
             (rng,),
             dist_params,
             (numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
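
Elemwise and RandomVariable opt back into the old behaviour by passing True: their inner functions are scalar ops that expect plain scalars for 0d core inputs, whereas Blockwise calls the core op's perform, which expects ndarrays. A rough Python mirror of what the new literal flag controls (unbox_core_input is a hypothetical name, not part of the codebase):

import numpy as np

def unbox_core_input(arr, core_ndim, allow_core_scalar):
    # Hypothetical mirror of the dispatch logic: collapse a 0d core
    # input to a scalar only when the caller opted in.
    if allow_core_scalar and core_ndim == 0:
        return arr.item()  # scalar, as Elemwise / RandomVariable expect
    return arr             # ndarray (possibly 0d), as Blockwise expects

assert isinstance(unbox_core_input(np.array(2.0), 0, True), float)
assert isinstance(unbox_core_input(np.array(2.0), 0, False), np.ndarray)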

pytensor/link/numba/dispatch/vectorize_codegen.py (19 additions, 8 deletions)

@@ -82,6 +82,7 @@ def _vectorized(
     output_bc_patterns,
     output_dtypes,
     inplace_pattern,
+    allow_core_scalar,
     constant_inputs_types,
     input_types,
     output_core_shape_types,
@@ -93,6 +94,7 @@ def _vectorized(
         output_bc_patterns,
         output_dtypes,
         inplace_pattern,
+        allow_core_scalar,
         constant_inputs_types,
         input_types,
         output_core_shape_types,
@@ -119,6 +121,10 @@ def _vectorized(
     inplace_pattern = inplace_pattern.literal_value
     inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode()))

+    if not isinstance(allow_core_scalar, types.Literal):
+        raise TypeError("allow_core_scalar must be literal.")
+    allow_core_scalar = allow_core_scalar.literal_value
+
     batch_ndim = len(input_bc_patterns[0])
     nin = len(constant_inputs_types) + len(input_types)
     nout = len(output_bc_patterns)
@@ -142,8 +148,7 @@ def _vectorized(
     core_input_types = []
     for input_type, bc_pattern in zip(input_types, input_bc_patterns, strict=True):
         core_ndim = input_type.ndim - len(bc_pattern)
-        # TODO: Reconsider this
-        if core_ndim == 0:
+        if allow_core_scalar and core_ndim == 0:
             core_input_type = input_type.dtype
         else:
             core_input_type = types.Array(
@@ -196,7 +201,7 @@ def codegen(
         sig,
         args,
     ):
-        [_, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args
+        [_, _, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args

         constant_inputs = cgutils.unpack_tuple(builder, constant_inputs)
         inputs = cgutils.unpack_tuple(builder, inputs)
@@ -256,6 +261,7 @@ def codegen(
             output_bc_patterns_val,
             input_types,
             output_types,
+            core_scalar=allow_core_scalar,
         )

         if len(outputs) == 1:
@@ -429,6 +435,7 @@ def make_loop_call(
     output_bc: tuple[tuple[bool, ...], ...],
     input_types: tuple[Any, ...],
     output_types: tuple[Any, ...],
+    core_scalar: bool = True,
 ):
     safe = (False, False)

@@ -486,7 +493,7 @@ def make_loop_call(
             idxs_bc,
             *safe,
         )
-        if core_ndim == 0:
+        if core_scalar and core_ndim == 0:
             # Retrive scalar item at index
             val = builder.load(ptr)
             # val.set_metadata("alias.scope", input_scope_set)
@@ -499,15 +506,19 @@ def make_loop_call(
                 dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout
             )
             core_array = context.make_array(core_arry_type)(context, builder)
-            core_shape = cgutils.unpack_tuple(builder, input.shape)[-core_ndim:]
-            core_strides = cgutils.unpack_tuple(builder, input.strides)[-core_ndim:]
+            core_shape = cgutils.unpack_tuple(builder, input.shape)[
+                input_type.ndim - core_ndim :
+            ]
+            core_strides = cgutils.unpack_tuple(builder, input.strides)[
+                input_type.ndim - core_ndim :
+            ]
             itemsize = context.get_abi_sizeof(context.get_data_type(input_type.dtype))
             context.populate_array(
                 core_array,
                 # TODO whey do we need to bitcast?
                 data=builder.bitcast(ptr, core_array.data.type),
-                shape=cgutils.pack_array(builder, core_shape),
-                strides=cgutils.pack_array(builder, core_strides),
+                shape=core_shape,
+                strides=core_strides,
                 itemsize=context.get_constant(types.intp, itemsize),
                 # TODO what is meminfo about?
                 meminfo=None,
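
Beyond threading the new allow_core_scalar literal through _vectorized (hence the extra underscore placeholder in codegen's argument unpacking), the key fix is in make_loop_call: once core_ndim can be 0 on the array path, slicing with [-core_ndim:] is wrong, because -0 == 0 in Python, so the slice keeps the whole shape/strides tuple instead of nothing. A quick demonstration:

shape = [5, 3, 4]
core_ndim = 0
print(shape[-core_ndim:])              # [5, 3, 4]: the whole list, since -0 == 0
print(shape[len(shape) - core_ndim:])  # []: the intended empty slice

Dropping the explicit pack_array calls in populate_array looks like the companion fix, presumably because packing an empty shape/strides sequence is exactly the 0d edge case this path now has to handle.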

tests/link/numba/test_blockwise.py (40 additions, 2 deletions)

@@ -2,10 +2,12 @@
 import pytest

 from pytensor import function
-from pytensor.tensor import lvector, tensor, tensor3
+from pytensor.graph import Apply
+from pytensor.scalar import ScalarOp
+from pytensor.tensor import TensorVariable, lvector, tensor, tensor3, vector
 from pytensor.tensor.basic import Alloc, ARange, constant
 from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
-from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.nlinalg import SVD, Det
 from pytensor.tensor.slinalg import Cholesky, cholesky
 from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
@@ -90,3 +92,39 @@ def test_blockwise_scalar_dimshuffle():
     )
     out = blockwise_scalar_ds(x)
     compare_numba_and_py([x], [out], [np.arange(9)], eval_obj_mode=False)
+
+
+def test_blockwise_vs_elemwise_scalar_op():
+    # Regression test for https://github.com/pymc-devs/pytensor/issues/1760
+
+    class TestScalarOp(ScalarOp):
+        def make_node(self, x):
+            return Apply(self, [x], [x.type()])
+
+        def perform(self, node, inputs, outputs):
+            [x] = inputs
+            if isinstance(node.inputs[0], TensorVariable):
+                assert isinstance(x, np.ndarray)
+            else:
+                assert isinstance(x, np.number | float)
+            out = x + 1
+            if isinstance(node.outputs[0], TensorVariable):
+                out = np.asarray(out)
+            outputs[0][0] = out
+
+    x = vector("x")
+    y = Elemwise(TestScalarOp())(x)
+    with pytest.warns(
+        UserWarning,
+        match="Numba will use object mode to run TestScalarOp's perform method",
+    ):
+        fn = function([x], y, mode="NUMBA")
+    np.testing.assert_allclose(fn(np.zeros((3,))), [1, 1, 1])
+
+    z = Blockwise(TestScalarOp(), signature="()->()")(x)
+    with pytest.warns(
+        UserWarning,
+        match="Numba will use object mode to run TestScalarOp's perform method",
+    ):
+        fn = function([x], z, mode="NUMBA")
+    np.testing.assert_allclose(fn(np.zeros((3,))), [1, 1, 1])
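
The regression test exercises both wrappers with the same core op: Elemwise feeds TestScalarOp's perform plain scalars, while Blockwise with signature "()->()" now feeds it 0d arrays, and both variants must still compile and run under the NUMBA backend (in object mode, hence the expected warnings, since the op has no native numba implementation).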
