Skip to content

Commit 9e324c7

Browse files
Propagate transpose option to numba dispatch
1 parent 79c161d commit 9e324c7

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1058,7 +1058,7 @@ def numba_funcify_Solve(op, node, **kwargs):
10581058
check_finite = op.check_finite
10591059
overwrite_a = op.overwrite_a
10601060
overwrite_b = op.overwrite_b
1061-
transposed = False # TODO: Solve doesnt currently allow the transposed argument
1061+
transposed = op.transposed
10621062

10631063
dtype = node.inputs[0].dtype
10641064
if str(dtype).startswith("complex"):

tests/link/numba/test_slinalg.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -401,16 +401,39 @@ def lu_solve(a, b, trans, overwrite_a, overwrite_b):
401401
assert_allclose(x, x_sp)
402402

403403

404+
solve_test_cases = [
405+
("gen", False, False),
406+
("gen", False, True),
407+
("sym", False, False),
408+
("sym", True, False),
409+
("sym", True, True),
410+
("pos", False, False),
411+
("pos", True, False),
412+
("pos", True, True),
413+
]
414+
solve_test_ids = [
415+
f'{assume_a}_{"lower" if lower else "upper"}_{"A^T" if transposed else "A"}'
416+
for assume_a, lower, transposed in solve_test_cases
417+
]
418+
419+
404420
@pytest.mark.parametrize(
405421
"b_shape",
406422
[(5, 1), (5, 5), (5,)],
407423
ids=["b_col_vec", "b_matrix", "b_vec"],
408424
)
409-
@pytest.mark.parametrize("assume_a", ["gen", "sym", "pos"], ids=str)
425+
@pytest.mark.parametrize(
426+
"assume_a, lower, transposed", solve_test_cases, ids=solve_test_ids
427+
)
410428
@pytest.mark.filterwarnings(
411429
'ignore:Cannot cache compiled function "numba_funcified_fgraph"'
412430
)
413-
def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]):
431+
def test_solve(
432+
b_shape: tuple[int],
433+
assume_a: Literal["gen", "sym", "pos"],
434+
lower: bool,
435+
transposed: bool,
436+
):
414437
A = pt.matrix("A", dtype=floatX)
415438
b = pt.tensor("b", shape=b_shape, dtype=floatX)
416439

@@ -424,10 +447,17 @@ def A_func(x):
424447
x = (x + x.T) / 2
425448
return x
426449

450+
def T(x):
451+
if transposed:
452+
return x.T
453+
return x
454+
427455
X = pt.linalg.solve(
428456
A_func(A),
429457
b,
430458
assume_a=assume_a,
459+
lower=lower,
460+
transposed=transposed,
431461
b_ndim=len(b_shape),
432462
)
433463
f = pytensor.function(
@@ -459,13 +489,18 @@ def A_func(x):
459489

460490
# Test that the result is numerically correct. Need to use the unmodified copy
461491
np.testing.assert_allclose(
462-
A_func(A_val_copy) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL
492+
T(A_func(A_val_copy)) @ X_np, b_val_copy, atol=ATOL, rtol=RTOL
463493
)
464494

465495
# See the note in tensor/test_slinalg.py::test_solve_correctness for details about the setup here
466496
utt.verify_grad(
467497
lambda A, b: pt.linalg.solve(
468-
A_func(A), b, lower=False, assume_a=assume_a, b_ndim=len(b_shape)
498+
A_func(A),
499+
b,
500+
lower=lower,
501+
transposed=transposed,
502+
assume_a=assume_a,
503+
b_ndim=len(b_shape),
469504
),
470505
[A_val_copy, b_val_copy],
471506
mode="NUMBA",

0 commit comments

Comments
 (0)