Skip to content

Commit af1ffb9

Browse files
Propagate transpose option to numba dispatch
1 parent c8d1fe1 commit af1ffb9

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1052,7 +1052,7 @@ def numba_funcify_Solve(op, node, **kwargs):
10521052
check_finite = op.check_finite
10531053
overwrite_a = op.overwrite_a
10541054
overwrite_b = op.overwrite_b
1055-
transposed = False # TODO: Solve doesnt currently allow the transposed argument
1055+
transposed = op.transposed
10561056

10571057
dtype = node.inputs[0].dtype
10581058
if str(dtype).startswith("complex"):

tests/link/numba/test_slinalg.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -358,16 +358,39 @@ def lu_solve(a, b, trans, overwrite_a, overwrite_b):
358358
assert_allclose(x, x_sp)
359359

360360

361+
solve_test_cases = [
362+
("gen", False, False),
363+
("gen", False, True),
364+
("sym", False, False),
365+
("sym", True, False),
366+
("sym", True, True),
367+
("pos", False, False),
368+
("pos", True, False),
369+
("pos", True, True),
370+
]
371+
solve_test_ids = [
372+
f'{assume_a}_{"lower" if lower else "upper"}_{"A^T" if transposed else "A"}'
373+
for assume_a, lower, transposed in solve_test_cases
374+
]
375+
376+
361377
@pytest.mark.parametrize(
362378
"b_shape",
363379
[(5, 1), (5, 5), (5,)],
364380
ids=["b_col_vec", "b_matrix", "b_vec"],
365381
)
366-
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
382+
@pytest.mark.parametrize(
383+
"assume_a, lower, transposed", solve_test_cases, ids=solve_test_ids
384+
)
367385
@pytest.mark.filterwarnings(
368386
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
369387
)
370-
def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]):
388+
def test_solve(
389+
b_shape: tuple[int],
390+
assume_a: Literal["gen", "sym", "pos"],
391+
lower: bool,
392+
transposed: bool,
393+
):
371394
A = pt.matrix("A", dtype=floatX)
372395
b = pt.tensor("b", shape=b_shape, dtype=floatX)
373396

@@ -381,10 +404,17 @@ def A_func(x):
381404
x = (x + x.T) / 2
382405
return x
383406

407+
def T(x):
408+
if transposed:
409+
return x.T
410+
return x
411+
384412
X = pt.linalg.solve(
385413
A_func(A),
386414
b,
387415
assume_a=assume_a,
416+
lower=lower,
417+
transposed=transposed,
388418
b_ndim=len(b_shape),
389419
)
390420
f = pytensor.function(
@@ -416,13 +446,18 @@ def A_func(x):
416446

417447
# Test that the result is numerically correct. Need to use the unmodified copy
418448
np.testing.assert_allclose(
419-
A_func(A_val_copy) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL
449+
T(A_func(A_val_copy)) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL
420450
)
421451

422452
# See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here
423453
utt.verify_grad(
424454
lambda A, b: pt.linalg.solve(
425-
A_func(A), b, lower=False, assume_a=assume_a, b_ndim=len(b_shape)
455+
A_func(A),
456+
b,
457+
lower=lower,
458+
transposed=transposed,
459+
assume_a=assume_a,
460+
b_ndim=len(b_shape),
426461
),
427462
[A_val_copy, b_val_copy],
428463
mode="NUMBA",

0 commit comments

Comments
 (0)