Commit dc22852

committed
Hardcode common Op parametrizations to allow numba caching
1 parent 17ef6a4 commit dc22852
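
Background, as a minimal standalone sketch (the add2/add3 names and the exec-based generation below are illustrative, not PyTensor's API): numba can only persist a compiled function to its on-disk cache when the function is backed by a real source file, so functions whose source is generated at runtime as a string cannot be cached. Hardcoding the most common Op parametrizations as ordinary statically defined functions makes them cache-eligible, while the remaining string-generated fallbacks are compiled explicitly with cache=False.

from numba import njit

# Statically defined in a real module: eligible for numba's on-disk cache.
@njit(cache=True)
def add2(i0, i1):
    return i0 + i1

# Generated at runtime (a rough stand-in for PyTensor's compile_function_src):
# there is no source file for numba to key the cache on, so it is compiled
# with caching disabled.
add3_src = "def add3(i0, i1, i2):\n    return i0 + i1 + i2\n"
namespace = {}
exec(add3_src, namespace)
add3 = njit(cache=False)(namespace["add3"])

assert add2(1.0, 2.0) == 3.0
assert add3(1.0, 2.0, 3.0) == 6.0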

5 files changed: 230 additions, 30 deletions


pytensor/link/numba/dispatch/elemwise.py

Lines changed: 9 additions & 1 deletion
@@ -411,7 +411,15 @@ def numba_funcify_CAReduce(op, node, **kwargs):


 @numba_funcify.register(DimShuffle)
-def numba_funcify_DimShuffle(op, node, **kwargs):
+def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs):
+    if op.is_left_expand_dims and op.new_order.count("x") == 1:
+        # Most common case, numba compiles it more quickly
+        @numba_njit
+        def left_expand_dims(x):
+            return np.expand_dims(x, 0)
+
+        return left_expand_dims
+
     # We use `as_strided` to achieve the DimShuffle behavior of transposing and expanding/squezing dimensions in one call
     # Numba doesn't currently support multiple expand/squeeze, and reshape is limited to contiguous arrays.
     new_order = tuple(op._new_order)
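
The fast path above handles the most common DimShuffle: a single leading "x" in new_order, i.e. adding one axis on the left. Because that kernel carries no Op-dependent parameters, it is an ordinary statically defined function. A standalone sketch of the equivalent behaviour (independent of PyTensor):

import numpy as np
from numba import njit

@njit(cache=True)
def left_expand_dims(x):
    # DimShuffle with new_order ("x", 0, 1, ...) just adds a leading axis.
    return np.expand_dims(x, 0)

assert left_expand_dims(np.ones((3, 4))).shape == (1, 3, 4)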

pytensor/link/numba/dispatch/scalar.py

Lines changed: 49 additions & 3 deletions
@@ -188,18 +188,64 @@ def {binary_op_name}({input_signature}):

 @numba_funcify.register(Add)
 def numba_funcify_Add(op, node, **kwargs):
+    match len(node.inputs):
+        case 2:
+
+            def add(i0, i1):
+                return i0 + i1
+        case 3:
+
+            def add(i0, i1, i2):
+                return i0 + i1 + i2
+        case 4:
+
+            def add(i0, i1, i2, i3):
+                return i0 + i1 + i2 + i3
+        case 5:
+
+            def add(i0, i1, i2, i3, i4):
+                return i0 + i1 + i2 + i3 + i4
+        case _:
+            add = None
+
+    if add is not None:
+        return numba_basic.numba_njit(add)
+
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

-    return numba_basic.numba_njit(signature)(nary_add_fn)
+    return numba_basic.numba_njit(signature, cache=False)(nary_add_fn)


 @numba_funcify.register(Mul)
 def numba_funcify_Mul(op, node, **kwargs):
+    match len(node.inputs):
+        case 2:
+
+            def mul(i0, i1):
+                return i0 * i1
+        case 3:
+
+            def mul(i0, i1, i2):
+                return i0 * i1 * i2
+        case 4:
+
+            def mul(i0, i1, i2, i3):
+                return i0 * i1 * i2 * i3
+        case 5:
+
+            def mul(i0, i1, i2, i3, i4):
+                return i0 * i1 * i2 * i3 * i4
+        case _:
+            mul = None
+
+    if mul is not None:
+        return numba_basic.numba_njit(mul)
+
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")

-    return numba_basic.numba_njit(signature)(nary_add_fn)
+    return numba_basic.numba_njit(signature, cache=False)(nary_add_fn)


 @numba_funcify.register(Cast)
@@ -249,7 +295,7 @@ def numba_funcify_Composite(op, node, **kwargs):

     _ = kwargs.pop("storage_map", None)

-    composite_fn = numba_basic.numba_njit(signature)(
+    composite_fn = numba_basic.numba_njit(signature, cache=False)(
         numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
     )
     return composite_fn

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 76 additions & 24 deletions
@@ -23,6 +23,60 @@
 def numba_funcify_default_subtensor(op, node, **kwargs):
     """Create a Python function that assembles and uses an index on an array."""

+    if isinstance(op, Subtensor) and len(op.idx_list) == 1:
+        # Hard code indices along first dimension to allow caching
+        [idx] = op.idx_list
+
+        if isinstance(idx, slice):
+            slice_info = (
+                idx.start is not None,
+                idx.stop is not None,
+                idx.step is not None,
+            )
+            match slice_info:
+                case (False, False, False):
+
+                    def subtensor(x):
+                        return x
+
+                case (True, False, False):
+
+                    def subtensor(x, start):
+                        return x[start:]
+                case (False, True, False):
+
+                    def subtensor(x, stop):
+                        return x[:stop]
+                case (False, False, True):
+
+                    def subtensor(x, step):
+                        return x[::step]
+
+                case (True, True, False):
+
+                    def subtensor(x, start, stop):
+                        return x[start:stop]
+                case (True, False, True):
+
+                    def subtensor(x, start, step):
+                        return x[start::step]
+                case (False, True, True):
+
+                    def subtensor(x, stop, step):
+                        return x[:stop:step]
+
+                case (True, True, True):
+
+                    def subtensor(x, start, stop, step):
+                        return x[start:stop:step]
+
+        else:
+
+            def subtensor(x, i):
+                return np.asarray(x[i])
+
+        return numba_njit(subtensor)
+
     unique_names = unique_name_generator(
         ["subtensor", "incsubtensor", "z"], suffix_sep="_"
     )
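
The slice_info triple records which of start/stop/step are actually present and therefore arrive as extra runtime inputs; each of the eight combinations gets a fixed-signature kernel whose source is static and hence cacheable. A standalone sketch of the (True, True, False) case, i.e. x[start:stop] (illustrative, outside PyTensor):

import numpy as np
from numba import njit

@njit(cache=True)
def subtensor(x, start, stop):
    return x[start:stop]

x = np.arange(10)
assert np.array_equal(subtensor(x, 2, 5), x[2:5])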
@@ -100,7 +154,7 @@ def {function_name}({", ".join(input_names)}):
         function_name=function_name,
         global_env=globals() | {"np": np},
     )
-    return numba_njit(func, boundscheck=True)
+    return numba_njit(func, boundscheck=True, cache=False)


 @numba_funcify.register(AdvancedSubtensor)
@@ -294,7 +348,9 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
         if broadcast:

             @numba_njit(boundscheck=True)
-            def advancedincsubtensor1_inplace(x, val, idxs):
+            def advanced_incsubtensor1(x, val, idxs):
+                out = x if inplace else x.copy()
+
                 if val.ndim == x.ndim:
                     core_val = val[0]
                 elif val.ndim == 0:
@@ -304,24 +360,28 @@ def advancedincsubtensor1_inplace(x, val, idxs):
                     core_val = val

                 for idx in idxs:
-                    x[idx] = core_val
-                return x
+                    out[idx] = core_val
+                return out

         else:

             @numba_njit(boundscheck=True)
-            def advancedincsubtensor1_inplace(x, vals, idxs):
+            def advanced_incsubtensor1(x, vals, idxs):
+                out = x if inplace else x.copy()
+
                 if not len(idxs) == len(vals):
                     raise ValueError("The number of indices and values must match.")
                 # no strict argument because incompatible with numba
                 for idx, val in zip(idxs, vals):  # noqa: B905
-                    x[idx] = val
-                return x
+                    out[idx] = val
+                return out
     else:
         if broadcast:

             @numba_njit(boundscheck=True)
-            def advancedincsubtensor1_inplace(x, val, idxs):
+            def advanced_incsubtensor1(x, val, idxs):
+                out = x if inplace else x.copy()
+
                 if val.ndim == x.ndim:
                     core_val = val[0]
                 elif val.ndim == 0:
@@ -331,29 +391,21 @@ def advancedincsubtensor1_inplace(x, val, idxs):
                     core_val = val

                 for idx in idxs:
-                    x[idx] += core_val
-                return x
+                    out[idx] += core_val
+                return out

         else:

             @numba_njit(boundscheck=True)
-            def advancedincsubtensor1_inplace(x, vals, idxs):
+            def advanced_incsubtensor1(x, vals, idxs):
+                out = x if inplace else x.copy()
+
                 if not len(idxs) == len(vals):
                     raise ValueError("The number of indices and values must match.")
                 # no strict argument because unsupported by numba
                 # TODO: this doesn't come up in tests
                 for idx, val in zip(idxs, vals):  # noqa: B905
-                    x[idx] += val
-                return x
-
-    if inplace:
-        return advancedincsubtensor1_inplace
-
-    else:
-
-        @numba_njit
-        def advancedincsubtensor1(x, vals, idxs):
-            x = x.copy()
-            return advancedincsubtensor1_inplace(x, vals, idxs)
+                    out[idx] += val
+                return out

-    return advancedincsubtensor1
+    return advanced_incsubtensor1
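
Instead of returning an in-place kernel and optionally wrapping it in a copying helper, each kernel now closes over the Python-level inplace flag and copies the input itself when needed; numba treats the closed-over boolean as a compile-time constant and specializes the compiled function on it. A minimal standalone sketch of the pattern (the flag and data below are illustrative):

import numpy as np
from numba import njit

inplace = False  # closed-over flag, frozen into the compiled function

@njit(boundscheck=True)
def advanced_incsubtensor1(x, val, idxs):
    out = x if inplace else x.copy()
    for idx in idxs:
        out[idx] += val
    return out

x = np.zeros(4)
out = advanced_incsubtensor1(x, 1.0, np.array([0, 2]))
assert out[0] == 1.0 and x[0] == 0.0  # input left untouched when inplace is False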

pytensor/link/numba/dispatch/vectorize_codegen.py

Lines changed: 92 additions & 1 deletion
@@ -35,6 +35,97 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on):
        on[...] = ton

    """
+    # Hardcode some cases for numba caching
+    match (nin, nout):
+        case (1, 1):
+
+            def func(i0, o0):
+                t0 = core_op_fn(i0)
+                o0[...] = t0
+        case (1, 2):
+
+            def func(i0, o0, o1):
+                t0, t1 = core_op_fn(i0)
+                o0[...] = t0
+                o1[...] = t1
+        case (1, 3):
+
+            def func(i0, o0, o1, o2):
+                t0, t1, t2 = core_op_fn(i0)
+                o0[...] = t0
+                o1[...] = t1
+                o2[...] = t2
+
+        case (2, 1):
+
+            def func(i0, i1, o0):
+                t0 = core_op_fn(i0, i1)
+                o0[...] = t0
+        case (2, 2):
+
+            def func(i0, i1, o0, o1):
+                t0, t1 = core_op_fn(i0, i1)
+                o0[...] = t0
+                o1[...] = t1
+        case (2, 3):
+
+            def func(i0, i1, o0, o1, o2):
+                t0, t1, t2 = core_op_fn(i0, i1)
+                o0[...] = t0
+                o1[...] = t1
+                o2[...] = t2
+
+        case (3, 1):
+
+            def func(i0, i1, i2, o0):
+                t0 = core_op_fn(i0, i1, i2)
+                o0[...] = t0
+
+        case (3, 2):
+
+            def func(i0, i1, i2, o0, o1):
+                t0, t1 = core_op_fn(i0, i1, i2)
+                o0[...] = t0
+                o1[...] = t1
+        case (3, 3):
+
+            def func(i0, i1, i2, o0, o1, o2):
+                t0, t1, t2 = core_op_fn(i0, i1, i2)
+                o0[...] = t0
+                o1[...] = t1
+                o2[...] = t2
+
+        case (4, 1):
+
+            def func(i0, i1, i2, i3, o0):
+                t0 = core_op_fn(i0, i1, i2, i3)
+                o0[...] = t0
+
+        case (4, 2):
+
+            def func(i0, i1, i2, i3, o0, o1):
+                t0, t1 = core_op_fn(i0, i1, i2, i3)
+                o0[...] = t0
+                o1[...] = t1
+
+        case (5, 1):
+
+            def func(i0, i1, i2, i3, i4, o0):
+                t0 = core_op_fn(i0, i1, i2, i3, i4)
+                o0[...] = t0
+
+        case (5, 2):
+
+            def func(i0, i1, i2, i3, i4, o0, o1):
+                t0, t1 = core_op_fn(i0, i1, i2, i3, i4)
+                o0[...] = t0
+                o1[...] = t1
+        case _:
+            func = None
+
+    if func is not None:
+        return numba_basic.numba_njit(func)
+
     inputs = [f"i{i}" for i in range(nin)]
     outputs = [f"o{i}" for i in range(nout)]
     inner_outputs = [f"t{output}" for output in outputs]
@@ -55,7 +146,7 @@ def store_core_outputs({inp_signature}, {out_signature}):
     func = compile_function_src(
         func_src, "store_core_outputs", {**globals(), **global_env}
     )
-    return cast(Callable, numba_basic.numba_njit(func))
+    return cast(Callable, numba_basic.numba_njit(func, cache=False))


 _jit_options = {
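
store_core_outputs wraps a core op function so that its results are written into preallocated output buffers (the "on[...] = ton" pattern shown in the docstring). Hardcoding the small (nin, nout) combinations avoids building that wrapper from a format string, keeping it cache-eligible. A standalone sketch of the (2, 1) case with a stand-in core_op_fn (an assumption, not PyTensor's actual core function):

import numpy as np
from numba import njit

@njit
def core_op_fn(i0, i1):
    # stand-in for the jitted core Op (assumption)
    return i0 + i1

@njit(cache=True)
def store_core_outputs(i0, i1, o0):
    t0 = core_op_fn(i0, i1)
    o0[...] = t0  # write the core result into the preallocated output buffer

o0 = np.empty(3)
store_core_outputs(np.ones(3), np.full(3, 2.0), o0)
assert np.array_equal(o0, np.full(3, 3.0))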

pytensor/link/numba/linker.py

Lines changed: 4 additions & 1 deletion
@@ -12,7 +12,10 @@ def fgraph_convert(self, fgraph, **kwargs):
     def jit_compile(self, fn):
         from pytensor.link.numba.dispatch.basic import numba_njit

-        jitted_fn = numba_njit(fn, no_cpython_wrapper=False, no_cfunc_wrapper=False)
+        # NUMBA can't cache our dynamically generated funcified_fgraph
+        jitted_fn = numba_njit(
+            fn, no_cpython_wrapper=False, no_cfunc_wrapper=False, cache=False
+        )
         return jitted_fn

     def create_thunk_inputs(self, storage_map):
