
Commit 86bfdc4

Hardcode common Op parametrizations to allow numba caching
1 parent 0b56ed9 commit 86bfdc4
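
Background for this commit: numba's on-disk cache only works for functions it can tie to static source and globals. PyTensor builds many of its kernels from generated source (via compile_function_src), which is why basic.py below filters the "Cannot cache compiled function ... as it uses dynamic globals" warning. This commit instead returns statically defined kernels for the most common Op parametrizations, which numba can cache, and explicitly marks the generated fallbacks with cache=False. A minimal sketch of the contrast (illustrative names, not from the commit; assumes numba is installed):

import numba

# Statically defined at module level: numba can persist the compiled
# function to its on-disk cache and skip recompilation in later runs.
@numba.njit(cache=True)
def add2(i0, i1):
    return i0 + i1

# Built from generated source, as PyTensor's compile_function_src does;
# kernels like this are not reliably cacheable, which is why the fallback
# paths below now pass cache=False explicitly.
src = "def add2_generated(i0, i1):\n    return i0 + i1\n"
namespace = {}
exec(src, namespace)
add2_generated = numba.njit(namespace["add2_generated"])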

7 files changed: +296 -36 lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 33 additions & 2 deletions
@@ -39,7 +39,7 @@
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
 from pytensor.tensor.slinalg import Solve
 from pytensor.tensor.type import TensorType
-from pytensor.tensor.type_other import MakeSlice, NoneConst
+from pytensor.tensor.type_other import MakeSlice, NoneConst, NoneTypeT


 def global_numba_func(func):
@@ -75,7 +75,7 @@ def numba_njit(*args, fastmath=None, **kwargs):
         message=(
             "(\x1b\\[1m)*"  # ansi escape code for bold text
             "Cannot cache compiled function "
-            '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
+            '"(store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
             "as it uses dynamic globals"
         ),
         category=NumbaWarning,
@@ -477,6 +477,37 @@ def reshape(x, shape):

 @numba_funcify.register(SpecifyShape)
 def numba_funcify_SpecifyShape(op, node, **kwargs):
+    x, *shape = node.inputs
+    ndim = x.type.ndim
+    specified_dims = tuple(not isinstance(dim.type, NoneTypeT) for dim in shape)
+    match (ndim, specified_dims):
+        case (1, (True,)):
+
+            def func(x, shape_0):
+                assert x.shape[0] == shape_0
+                return x
+        case (2, (True, False)):
+
+            def func(x, shape_0, shape_1):
+                assert x.shape[0] == shape_0
+                return x
+        case (2, (False, True)):
+
+            def func(x, shape_0, shape_1):
+                assert x.shape[1] == shape_1
+                return x
+        case (2, (True, True)):
+
+            def func(x, shape_0, shape_1):
+                assert x.shape[0] == shape_0
+                assert x.shape[1] == shape_1
+                return x
+        case _:
+            func = None
+
+    if func is not None:
+        return numba_njit(func)
+
     shape_inputs = node.inputs[1:]
     shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]
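
A note on the SpecifyShape fast path above: the kernel keeps one argument per shape input, so an unspecified (None) dim still occupies a slot and is simply never asserted. A standalone sketch of one dispatch case (illustrative, plain Python 3.10+, names are not from the codebase):

import numpy as np

def specify_shape_kernel(ndim, specified_dims):
    # Mirrors one case of the match above; unspecified dims keep their
    # argument slot so the signature lines up with node.inputs.
    match (ndim, specified_dims):
        case (2, (True, False)):
            def func(x, shape_0, shape_1):
                assert x.shape[0] == shape_0
                return x
            return func
        case _:
            return None  # fall back to the generated-source path

func = specify_shape_kernel(2, (True, False))
func(np.zeros((3, 4)), 3, None)  # passes; shape_1 is unused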

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 9 additions & 1 deletion
@@ -411,7 +411,15 @@ def numba_funcify_CAReduce(op, node, **kwargs):


 @numba_funcify.register(DimShuffle)
-def numba_funcify_DimShuffle(op, node, **kwargs):
+def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs):
+    if op.is_left_expand_dims and op.new_order.count("x") == 1:
+        # Most common case, numba compiles it more quickly
+        @numba_njit
+        def left_expand_dims(x):
+            return np.expand_dims(x, 0)
+
+        return left_expand_dims
+
     # We use `as_strided` to achieve the DimShuffle behavior of transposing and expanding/squeezing dimensions in one call
     # Numba doesn't currently support multiple expand/squeeze, and reshape is limited to contiguous arrays.
     new_order = tuple(op._new_order)
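
A quick check of the DimShuffle fast path (illustrative): a new_order that only prepends a single "x" is exactly np.expand_dims(x, 0), which is quicker for numba to compile (per the comment above) and, being statically defined, cacheable.

import numpy as np

x = np.arange(6).reshape(2, 3)
# Equivalent to a DimShuffle with new_order == ("x", 0, 1):
assert np.expand_dims(x, 0).shape == (1, 2, 3)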

pytensor/link/numba/dispatch/scalar.py

Lines changed: 49 additions & 3 deletions
@@ -172,18 +172,64 @@ def {binary_op_name}({input_signature}):

 @numba_funcify.register(Add)
 def numba_funcify_Add(op, node, **kwargs):
+    match len(node.inputs):
+        case 2:
+
+            def add(i0, i1):
+                return i0 + i1
+        case 3:
+
+            def add(i0, i1, i2):
+                return i0 + i1 + i2
+        case 4:
+
+            def add(i0, i1, i2, i3):
+                return i0 + i1 + i2 + i3
+        case 5:
+
+            def add(i0, i1, i2, i3, i4):
+                return i0 + i1 + i2 + i3 + i4
+        case _:
+            add = None
+
+    if add is not None:
+        return numba_basic.numba_njit(add)
+
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

-    return numba_basic.numba_njit(signature)(nary_add_fn)
+    return numba_basic.numba_njit(signature, cache=False)(nary_add_fn)


 @numba_funcify.register(Mul)
 def numba_funcify_Mul(op, node, **kwargs):
+    match len(node.inputs):
+        case 2:
+
+            def mul(i0, i1):
+                return i0 * i1
+        case 3:
+
+            def mul(i0, i1, i2):
+                return i0 * i1 * i2
+        case 4:
+
+            def mul(i0, i1, i2, i3):
+                return i0 * i1 * i2 * i3
+        case 5:
+
+            def mul(i0, i1, i2, i3, i4):
+                return i0 * i1 * i2 * i3 * i4
+        case _:
+            mul = None
+
+    if mul is not None:
+        return numba_basic.numba_njit(mul)
+
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")

-    return numba_basic.numba_njit(signature)(nary_add_fn)
+    return numba_basic.numba_njit(signature, cache=False)(nary_add_fn)


 @numba_funcify.register(Cast)
@@ -233,7 +279,7 @@ def numba_funcify_Composite(op, node, **kwargs):

     _ = kwargs.pop("storage_map", None)

-    composite_fn = numba_basic.numba_njit(signature)(
+    composite_fn = numba_basic.numba_njit(signature, cache=False)(
         numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
     )
     return composite_fn
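
Design note on the Add/Mul hunks: arities 2 through 5 get statically defined kernels, while higher arities keep the binary_to_nary_func generated source, now explicitly marked cache=False. A minimal sketch of the cacheable case (illustrative, assumes numba is installed):

import numba

# The three-input Add specialization as a standalone, cacheable kernel.
@numba.njit(cache=True)
def add3(i0, i1, i2):
    return i0 + i1 + i2

assert add3(1.0, 2.0, 3.0) == 6.0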

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 76 additions & 24 deletions
@@ -23,6 +23,60 @@
 def numba_funcify_default_subtensor(op, node, **kwargs):
     """Create a Python function that assembles and uses an index on an array."""

+    if isinstance(op, Subtensor) and len(op.idx_list) == 1:
+        # Hard code indices along first dimension to allow caching
+        [idx] = op.idx_list
+
+        if isinstance(idx, slice):
+            slice_info = (
+                idx.start is not None,
+                idx.stop is not None,
+                idx.step is not None,
+            )
+            match slice_info:
+                case (False, False, False):
+
+                    def subtensor(x):
+                        return x
+
+                case (True, False, False):
+
+                    def subtensor(x, start):
+                        return x[start:]
+                case (False, True, False):
+
+                    def subtensor(x, stop):
+                        return x[:stop]
+                case (False, False, True):
+
+                    def subtensor(x, step):
+                        return x[::step]
+
+                case (True, True, False):
+
+                    def subtensor(x, start, stop):
+                        return x[start:stop]
+                case (True, False, True):
+
+                    def subtensor(x, start, step):
+                        return x[start::step]
+                case (False, True, True):
+
+                    def subtensor(x, stop, step):
+                        return x[:stop:step]
+
+                case (True, True, True):
+
+                    def subtensor(x, start, stop, step):
+                        return x[start:stop:step]
+
+        else:
+
+            def subtensor(x, i):
+                return np.asarray(x[i])
+
+        return numba_njit(subtensor)
+
     unique_names = unique_name_generator(
         ["subtensor", "incsubtensor", "z"], suffix_sep="_"
     )
@@ -100,7 +154,7 @@ def {function_name}({", ".join(input_names)}):
         function_name=function_name,
         global_env=globals() | {"np": np},
     )
-    return numba_njit(func, boundscheck=True)
+    return numba_njit(func, boundscheck=True, cache=False)


 @numba_funcify.register(AdvancedSubtensor)
@@ -294,7 +348,9 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
         if broadcast:

             @numba_njit(boundscheck=True)
-            def advancedincsubtensor1_inplace(x, val, idxs):
+            def advanced_incsubtensor1(x, val, idxs):
+                out = x if inplace else x.copy()
+
                 if val.ndim == x.ndim:
                     core_val = val[0]
                 elif val.ndim == 0:
@@ -304,24 +360,28 @@ def advancedincsubtensor1_inplace(x, val, idxs):
                     core_val = val

                 for idx in idxs:
-                    x[idx] = core_val
-                return x
+                    out[idx] = core_val
+                return out

         else:

             @numba_njit(boundscheck=True)
-            def advancedincsubtensor1_inplace(x, vals, idxs):
+            def advanced_incsubtensor1(x, vals, idxs):
+                out = x if inplace else x.copy()
+
                 if not len(idxs) == len(vals):
                     raise ValueError("The number of indices and values must match.")
                 # no strict argument because incompatible with numba
                 for idx, val in zip(idxs, vals):  # noqa: B905
-                    x[idx] = val
-                return x
+                    out[idx] = val
+                return out
     else:
         if broadcast:

             @numba_njit(boundscheck=True)
-            def advancedincsubtensor1_inplace(x, val, idxs):
+            def advanced_incsubtensor1(x, val, idxs):
+                out = x if inplace else x.copy()
+
                 if val.ndim == x.ndim:
                     core_val = val[0]
                 elif val.ndim == 0:
@@ -331,29 +391,21 @@ def advancedincsubtensor1_inplace(x, val, idxs):
                     core_val = val

                 for idx in idxs:
-                    x[idx] += core_val
-                return x
+                    out[idx] += core_val
+                return out

         else:

             @numba_njit(boundscheck=True)
-            def advancedincsubtensor1_inplace(x, vals, idxs):
+            def advanced_incsubtensor1(x, vals, idxs):
+                out = x if inplace else x.copy()
+
                 if not len(idxs) == len(vals):
                     raise ValueError("The number of indices and values must match.")
                 # no strict argument because unsupported by numba
                 # TODO: this doesn't come up in tests
                 for idx, val in zip(idxs, vals):  # noqa: B905
-                    x[idx] += val
-                return x
-
-    if inplace:
-        return advancedincsubtensor1_inplace
-
-    else:
-
-        @numba_njit
-        def advancedincsubtensor1(x, vals, idxs):
-            x = x.copy()
-            return advancedincsubtensor1_inplace(x, vals, idxs)
+                    out[idx] += val
+                return out

-    return advancedincsubtensor1
+    return advanced_incsubtensor1
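
Two design points in this file's changes. First, a lone index along the first dimension is dispatched on which slice components are present, giving a small family of static kernels; a standalone sketch of one case (illustrative, names not from the codebase):

import numpy as np

def pick_subtensor(slice_info):
    # slice_info = (start is not None, stop is not None, step is not None)
    match slice_info:
        case (True, False, True):
            def subtensor(x, start, step):
                return x[start::step]
            return subtensor
        case _:
            return None  # remaining cases handled as in the diff

f = pick_subtensor((True, False, True))
assert (f(np.arange(10), 2, 3) == np.arange(10)[2::3]).all()

Second, AdvancedIncSubtensor1 now folds the non-inplace copy into each kernel (out = x if inplace else x.copy()), replacing the former wrapper that copied x and then called the inplace variant, so only one kernel is defined per parametrization.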

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 33 additions & 4 deletions
@@ -58,6 +58,36 @@ def allocempty({", ".join(shape_var_names)}):

 @numba_funcify.register(Alloc)
 def numba_funcify_Alloc(op, node, **kwargs):
+    x, *shape = node.inputs
+    if all(x.type.broadcastable):
+        match len(shape):
+            case 1:
+
+                def alloc(val, dim0):
+                    shape = (dim0.item(),)
+                    res = np.empty(shape, dtype=val.dtype)
+                    res[...] = val
+                    return res
+            case 2:
+
+                def alloc(val, dim0, dim1):
+                    shape = (dim0.item(), dim1.item())
+                    res = np.empty(shape, dtype=val.dtype)
+                    res[...] = val
+                    return res
+            case 3:
+
+                def alloc(val, dim0, dim1, dim2):
+                    shape = (dim0.item(), dim1.item(), dim2.item())
+                    res = np.empty(shape, dtype=val.dtype)
+                    res[...] = val
+                    return res
+            case _:
+                alloc = None
+
+        if alloc is not None:
+            return numba_basic.numba_njit(alloc)
+
     global_env = {"np": np, "to_scalar": numba_basic.to_scalar}

     unique_names = unique_name_generator(
@@ -68,7 +98,7 @@ def numba_funcify_Alloc(op, node, **kwargs):
     shape_var_item_names = [f"{name}_item" for name in shape_var_names]
     shapes_to_items_src = indent(
         "\n".join(
-            f"{item_name} = to_scalar({shape_name})"
+            f"{item_name} = {shape_name}.item()"
             for item_name, shape_name in zip(
                 shape_var_item_names, shape_var_names, strict=True
             )
@@ -86,12 +116,11 @@ def numba_funcify_Alloc(op, node, **kwargs):

     alloc_def_src = f"""
 def alloc(val, {", ".join(shape_var_names)}):
-    val_np = np.asarray(val)
 {shapes_to_items_src}
     scalar_shape = {create_tuple_string(shape_var_item_names)}
 {check_runtime_broadcast_src}
-    res = np.empty(scalar_shape, dtype=val_np.dtype)
-    res[...] = val_np
+    res = np.empty(scalar_shape, dtype=val.dtype)
+    res[...] = val
     return res
 """
     alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})
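
An illustrative check of the Alloc fast path above: when the value is fully broadcastable (scalar-like), the kernel just allocates and broadcast-assigns, with each dim arriving as a 0-d array whose .item() yields the scalar (the generated path now uses .item() as well, replacing to_scalar):

import numpy as np

def alloc(val, dim0, dim1):
    # Matches the case-2 kernel above, outside of numba for illustration.
    shape = (dim0.item(), dim1.item())
    res = np.empty(shape, dtype=val.dtype)
    res[...] = val
    return res

out = alloc(np.float64(3.0), np.int64(2), np.int64(4))
assert out.shape == (2, 4) and (out == 3.0).all()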
