Skip to content

Commit ee47dcc

Browse files
committed
Numba BlockDiag: Fix failure with mixed readable/non-readable arrays
1 parent c48a8b3 commit ee47dcc

File tree

3 files changed

+107
-13
lines changed

3 files changed

+107
-13
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44

55
from pytensor import config
6+
from pytensor.link.numba.cache import compile_numba_function_src
67
from pytensor.link.numba.dispatch import basic as numba_basic
78
from pytensor.link.numba.dispatch.basic import (
89
generate_fallback_impl,
@@ -30,6 +31,10 @@
3031
from pytensor.link.numba.dispatch.linalg.solve.symmetric import _solve_symmetric
3132
from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangular
3233
from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal
34+
from pytensor.link.numba.dispatch.string_codegen import (
35+
CODE_TOKEN,
36+
build_source_code,
37+
)
3338
from pytensor.tensor.slinalg import (
3439
LU,
3540
QR,
@@ -222,24 +227,69 @@ def lu_factor(a):
222227

223228
@register_funcify_default_op_cache_key(BlockDiagonal)
def numba_funcify_BlockDiagonal(op, node, **kwargs):
    """Build the Numba implementation of ``BlockDiagonal`` through codegen.

    Because we have variadic arguments we need to use codegen: the generated
    function takes one explicit positional argument per input of ``node``
    (``arr0``, ``arr1``, ...) rather than ``*args``.

    For three inputs and a ``float64`` output, the generated code looks
    something like::

        def block_diagonal(arr0, arr1, arr2):
            out_r = arr0.shape[0] + arr1.shape[0] + arr2.shape[0]
            out_c = arr0.shape[1] + arr1.shape[1] + arr2.shape[1]
            out = np.zeros((out_r, out_c), dtype=np.float64)

            r, c = 0, 0
            rr, cc = arr0.shape
            out[r: r + rr, c: c + cc] = arr0
            r += rr
            c += cc

            rr, cc = arr1.shape
            out[r: r + rr, c: c + cc] = arr1
            r += rr
            c += cc

            rr, cc = arr2.shape
            out[r: r + rr, c: c + cc] = arr2
            r += rr
            c += cc

            return out

    Parameters
    ----------
    op
        The Op being dispatched; not read by this function.
    node
        Apply node; supplies the number of inputs and the output dtype.
    **kwargs
        Accepted for dispatch compatibility; not read by this function.

    Returns
    -------
    tuple
        The ``numba_njit``-wrapped generated function and its cache key.
    """
    dtype = node.outputs[0].dtype
    arg_names = [f"arr{i}" for i in range(len(node.inputs))]

    # Header: signature, output allocation and the running (row, col) offset.
    code = [
        f"def block_diagonal({', '.join(arg_names)}):",
        CODE_TOKEN.INDENT,
        f"out_r = {' + '.join(f'{a}.shape[0]' for a in arg_names)}",
        f"out_c = {' + '.join(f'{a}.shape[1]' for a in arg_names)}",
        f"out = np.zeros((out_r, out_c), dtype=np.{dtype})",
        CODE_TOKEN.EMPTY_LINE,
        "r, c = 0, 0",
    ]
    # One unrolled stanza per input: copy it at the current offset, then
    # advance the offset past it along both axes.
    for arg_name in arg_names:
        code.extend(
            [
                f"rr, cc = {arg_name}.shape",
                f"out[r: r + rr, c: c + cc] = {arg_name}",
                "r += rr",
                "c += cc",
                CODE_TOKEN.EMPTY_LINE,
            ]
        )
    code.append("return out")

    code_txt = build_source_code(code)
    block_diag = compile_numba_function_src(
        code_txt,
        "block_diagonal",
        globals() | {"np": np},
    )

    # NOTE(review): constant key; presumably a version tag for this generated
    # implementation -- confirm against the caching machinery.
    cache_key = 1
    return numba_basic.numba_njit(block_diag), cache_key
243293

244294

245295
@register_funcify_default_op_cache_key(Solve)
pytensor/link/numba/dispatch/string_codegen.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,31 @@
1+
from collections.abc import Sequence
2+
from enum import Enum, auto
3+
4+
15
def create_tuple_string(x):
    """Return the source text of a tuple literal built from the items of *x*.

    Every item is interpolated with f-string formatting, so both strings and
    other objects (e.g. integers) are accepted regardless of how many items
    there are.

    Parameters
    ----------
    x
        A sized, indexable collection of items to place inside the literal.

    Returns
    -------
    str
        ``"(a,)"`` for a single item (the trailing comma keeps it a tuple),
        otherwise the items joined by ``", "`` inside parentheses; an empty
        *x* yields ``"()"``.
    """
    if len(x) == 1:
        # A one-element tuple literal needs the trailing comma.
        return f"({x[0]},)"
    # Format each item the same way the single-item branch does, so that
    # non-string items do not make ``str.join`` raise a TypeError.
    return f"({', '.join(f'{item}' for item in x)})"
10+
11+
12+
class CODE_TOKEN(Enum):
    """Layout markers that can be mixed with source lines in ``build_source_code``."""

    # Raise the indentation level applied to the lines that follow.
    INDENT = auto()
    # Lower the indentation level applied to the lines that follow.
    DEDENT = auto()
    # Emit an empty line, regardless of the current indentation level.
    EMPTY_LINE = auto()
16+
17+
18+
def build_source_code(code: Sequence[str | CODE_TOKEN]) -> str:
    """Assemble source text from a sequence of lines and layout tokens.

    Plain entries are emitted as lines prefixed with the current indentation.
    ``CODE_TOKEN.INDENT`` and ``CODE_TOKEN.DEDENT`` change the indentation
    level used for subsequent lines, and ``CODE_TOKEN.EMPTY_LINE`` emits an
    empty line.

    Parameters
    ----------
    code
        Source lines (without leading indentation) interleaved with tokens.

    Returns
    -------
    str
        The emitted lines joined by newlines, with no trailing newline.

    Raises
    ------
    ValueError
        If a ``DEDENT`` token would take the indentation level below zero.
    """
    lines = []
    indentation_level = 0
    for line in code:
        if line is CODE_TOKEN.INDENT:
            indentation_level += 1
        elif line is CODE_TOKEN.DEDENT:
            # Raise explicitly instead of using ``assert``: assertions are
            # stripped under ``python -O``, and a negative level would then
            # silently produce unindented lines.
            if indentation_level == 0:
                raise ValueError(
                    "Unbalanced CODE_TOKEN.DEDENT: indentation level would become negative"
                )
            indentation_level -= 1
        elif line is CODE_TOKEN.EMPTY_LINE:
            lines.append("")
        else:
            lines.append(f"{' ' * indentation_level}{line}")
    return "\n".join(lines)

tests/link/numba/test_slinalg.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,24 @@ def test_block_diag():
811811
compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val])
812812

813813

814+
def test_block_diag_with_read_only_inp():
    """Check that block_diag compiles with a mix of writeable and read-only inputs."""
    # Regression test where numba would complain about *args containing both read-only and regular inputs
    # Currently, constants are read-only for numba, but for future-proofing we add an explicitly read-only input as well
    x = pt.tensor("x", shape=(2, 2))
    # Give every variable a distinct label that matches its role.
    x_read_only = pt.tensor("x_read_only", shape=(2, 2))
    x_const = pt.constant(np.ones((2, 2), dtype=x.type.dtype), name="x_const")
    out = pt.linalg.block_diag(x, x_read_only, x_const)

    x_test = np.ones((2, 2), dtype=x.type.dtype)
    x_read_only_test = x_test.copy()
    # Mark the second test value as non-writeable.
    x_read_only_test.flags.writeable = False
    compare_numba_and_py(
        [x, x_read_only],
        [out],
        [x_test, x_read_only_test],
    )
830+
831+
814832
@pytest.mark.parametrize("inverse", [True, False], ids=["p_inv", "p"])
815833
def test_pivot_to_permutation(inverse):
816834
from pytensor.tensor.slinalg import pivot_to_permutation

0 commit comments

Comments
 (0)