Skip to content

Commit fb308ca

Browse files
committed
Numba BlockDiag: Fix failure with mixed readable/non-readable arrays
1 parent 52c9ef8 commit fb308ca

File tree

3 files changed

+103
-13
lines changed

3 files changed

+103
-13
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44

55
from pytensor import config
6+
from pytensor.link.numba.cache import compile_numba_function_src
67
from pytensor.link.numba.dispatch import basic as numba_basic
78
from pytensor.link.numba.dispatch.basic import (
89
generate_fallback_impl,
@@ -30,6 +31,10 @@
3031
from pytensor.link.numba.dispatch.linalg.solve.symmetric import _solve_symmetric
3132
from pytensor.link.numba.dispatch.linalg.solve.triangular import _solve_triangular
3233
from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal
34+
from pytensor.link.numba.dispatch.string_codegen import (
35+
CODE_TOKEN,
36+
build_source_code,
37+
)
3338
from pytensor.tensor.slinalg import (
3439
LU,
3540
QR,
@@ -222,24 +227,65 @@ def lu_factor(a):
222227

223228
@register_funcify_default_op_cache_key(BlockDiagonal)
def numba_funcify_BlockDiagonal(op, node, **kwargs):
    """Build a Numba implementation of ``BlockDiagonal`` by generating source.

    The function is generated with one explicit positional argument per input
    rather than a single ``*arrs`` parameter: with ``*arrs`` numba failed when
    the inputs mixed read-only and regular (writeable) arrays.

    Codegen something like:

    .. code-block:: python

        def block_diagonal(arr0, arr1, arr2):
            out_r = arr0.shape[0] + arr1.shape[0] + arr2.shape[0]
            out_c = arr0.shape[1] + arr1.shape[1] + arr2.shape[1]
            out = np.zeros((out_r, out_c), dtype=np.float64)

            r, c = 0, 0
            rr, cc = arr0.shape
            out[r: r + rr, c: c + cc] = arr0
            r += rr
            c += cc

            rr, cc = arr1.shape
            out[r: r + rr, c: c + cc] = arr1
            r += rr
            c += cc

            rr, cc = arr2.shape
            out[r: r + rr, c: c + cc] = arr2
            r += rr
            c += cc

            return out

    Parameters
    ----------
    op
        The ``BlockDiagonal`` op being dispatched (not used directly here).
    node
        The apply node; supplies the number of inputs and the output dtype.
    **kwargs
        Ignored.

    Returns
    -------
    tuple
        ``(njit_function, cache_key)``.
    """
    dtype = node.outputs[0].dtype
    n_inp = len(node.inputs)

    arg_names = [f"arr{i}" for i in range(n_inp)]
    code = [
        f"def block_diagonal({', '.join(arg_names)}):",
        CODE_TOKEN.INDENT,
        # Output shape is the sum of the input shapes along each axis.
        f"out_r = {' + '.join(f'{a}.shape[0]' for a in arg_names)}",
        f"out_c = {' + '.join(f'{a}.shape[1]' for a in arg_names)}",
        f"out = np.zeros((out_r, out_c), dtype=np.{dtype})",
        CODE_TOKEN.EMPTY_LINE,
        "r, c = 0, 0",
    ]
    # Unrolled copy: each input is written at the running (r, c) offset, which
    # then advances by that input's shape.
    for arg_name in arg_names:
        code.extend(
            [
                f"rr, cc = {arg_name}.shape",
                f"out[r: r + rr, c: c + cc] = {arg_name}",
                "r += rr",
                "c += cc",
                CODE_TOKEN.EMPTY_LINE,
            ]
        )
    code.append("return out")

    code_txt = build_source_code(code)
    block_diag = compile_numba_function_src(
        code_txt,
        "block_diagonal",
        globals() | {"np": np},
    )

    # NOTE(review): constant key; presumably the default op cache key applied by
    # the decorator already distinguishes nodes - confirm.
    cache_key = 1
    return numba_basic.numba_njit(block_diag), cache_key
243289

244290

245291
@register_funcify_default_op_cache_key(Solve)
pytensor/link/numba/dispatch/string_codegen.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,31 @@
1+
from collections.abc import Sequence
2+
from enum import Enum, auto
3+
4+
15
def create_tuple_string(x):
    """Return the source text of a tuple literal built from the items of ``x``.

    A single item is rendered with a trailing comma (``"(a,)"``) so the result
    is still a tuple when evaluated; an empty ``x`` yields ``"()"``.

    Parameters
    ----------
    x
        Sized, indexable collection of items. Each item is converted with
        ``str`` so non-string items are handled the same way in both branches.

    Returns
    -------
    str
        The tuple literal as text.
    """
    if len(x) == 1:
        return f"({x[0]},)"
    # ``str.join`` only accepts strings, whereas the single-item branch formats
    # any object; convert explicitly so both branches accept the same inputs.
    return f"({', '.join(map(str, x))})"
10+
11+
12+
class CODE_TOKEN(Enum):
13+
INDENT = auto()
14+
DEDENT = auto()
15+
EMPTY_LINE = auto()
16+
17+
18+
def build_source_code(code: Sequence[str | CODE_TOKEN]) -> str:
19+
lines = []
20+
indentation_level = 0
21+
for line in code:
22+
if line is CODE_TOKEN.INDENT:
23+
indentation_level += 1
24+
elif line is CODE_TOKEN.DEDENT:
25+
indentation_level -= 1
26+
assert indentation_level >= 0
27+
elif line is CODE_TOKEN.EMPTY_LINE:
28+
lines.append("")
29+
else:
30+
lines.append(f"{' ' * indentation_level}{line}")
31+
return "\n".join(lines)

tests/link/numba/test_slinalg.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,24 @@ def test_block_diag():
811811
compare_numba_and_py([A, B, C, D], [X], [A_val, B_val, C_val, D_val])
812812

813813

814+
def test_block_diag_with_read_only_inp():
    # Regression test where numba would complain about *args containing both
    # read-only and regular inputs.
    # Currently, constants are read-only for numba, but for future-proofing we
    # add an explicitly read-only input as well.
    x = pt.tensor("x", shape=(2, 2))
    # Give each variable a name matching its role; previously the read-only
    # input reused the name "x" and the constant was labelled "x_read_only".
    x_read_only = pt.tensor("x_read_only", shape=(2, 2))
    x_const = pt.constant(np.ones((2, 2), dtype=x.type.dtype), name="x_const")
    out = pt.linalg.block_diag(x, x_read_only, x_const)

    x_test = np.ones((2, 2), dtype=x.type.dtype)
    x_read_only_test = x_test.copy()
    # Mark the second input as non-writeable so the inputs mix regular and
    # read-only arrays.
    x_read_only_test.flags.writeable = False
    compare_numba_and_py(
        [x, x_read_only],
        [out],
        [x_test, x_read_only_test],
    )
830+
831+
814832
@pytest.mark.parametrize("inverse", [True, False], ids=["p_inv", "p"])
815833
def test_pivot_to_permutation(inverse):
816834
from pytensor.tensor.slinalg import pivot_to_permutation

0 commit comments

Comments
 (0)