Skip to content

Commit 1ddd529

Browse files
Rename b to x, matching BLAS docs
1 parent 161e172 commit 1ddd529

File tree

2 files changed

+22
-22
lines changed

2 files changed

+22
-22
lines changed

pytensor/tensor/slinalg.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,20 +1678,20 @@ def __init__(self, lower_diags, upper_diags):
16781678
self.lower_diags = lower_diags
16791679
self.upper_diags = upper_diags
16801680

1681-
def make_node(self, A, b):
1681+
def make_node(self, A, x):
16821682
A = as_tensor_variable(A)
1683-
B = as_tensor_variable(b)
1683+
x = as_tensor_variable(x)
16841684

1685-
out_dtype = pytensor.scalar.upcast(A.dtype, B.dtype)
1686-
output = b.type.clone(dtype=out_dtype)()
1685+
out_dtype = pytensor.scalar.upcast(A.dtype, x.dtype)
1686+
output = x.type.clone(dtype=out_dtype)()
16871687

1688-
return pytensor.graph.basic.Apply(self, [A, B], [output])
1688+
return pytensor.graph.basic.Apply(self, [A, x], [output])
16891689

16901690
def infer_shape(self, fgraph, nodes, shapes):
16911691
return [shapes[0][:-1]]
16921692

16931693
def perform(self, node, inputs, outputs_storage):
1694-
A, b = inputs
1694+
A, x = inputs
16951695
m, n = A.shape
16961696

16971697
kl = self.lower_diags
@@ -1703,10 +1703,10 @@ def perform(self, node, inputs, outputs_storage):
17031703
A_banded[i, slice(k, None) if k >= 0 else slice(None, n + k)] = diag(A, k=k)
17041704

17051705
fn = scipy_linalg.get_blas_funcs("gbmv", dtype=A.dtype)
1706-
outputs_storage[0][0] = fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=b)
1706+
outputs_storage[0][0] = fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x)
17071707

17081708

1709-
def banded_dot(A: TensorLike, b: TensorLike, lower_diags: int, upper_diags: int):
1709+
def banded_dot(A: TensorLike, x: TensorLike, lower_diags: int, upper_diags: int):
17101710
"""
17111711
Specialized matrix-vector multiplication for cases when A is a banded matrix
17121712
@@ -1719,7 +1719,7 @@ def banded_dot(A: TensorLike, b: TensorLike, lower_diags: int, upper_diags: int)
17191719
----------
17201720
A: Tensorlike
17211721
Matrix to perform banded dot on.
1722-
b: Tensorlike
1722+
x: Tensorlike
17231723
Vector to perform banded dot on.
17241724
lower_diags: int
17251725
Number of nonzero lower diagonals of A
@@ -1731,7 +1731,7 @@ def banded_dot(A: TensorLike, b: TensorLike, lower_diags: int, upper_diags: int)
17311731
out: Tensor
17321732
The matrix multiplication result
17331733
"""
1734-
return Blockwise(BandedDot(lower_diags, upper_diags))(A, b)
1734+
return Blockwise(BandedDot(lower_diags, upper_diags))(A, x)
17351735

17361736

17371737
__all__ = [

tests/tensor/test_slinalg.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,19 +1074,19 @@ def test_banded_dot(A_shape, kl, ku):
10741074
rng = np.random.default_rng()
10751075

10761076
A_val = _make_banded_A(rng.normal(size=A_shape), kl=kl, ku=ku).astype(config.floatX)
1077-
b_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX)
1077+
x_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX)
10781078

10791079
A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype)
1080-
b = pt.tensor("b", shape=b_val.shape, dtype=b_val.dtype)
1081-
res = banded_dot(A, b, kl, ku)
1082-
res_2 = A @ b
1080+
x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype)
1081+
res = banded_dot(A, x, kl, ku)
1082+
res_2 = A @ x
10831083

1084-
fn = function([A, b], [res, res_2], trust_input=True)
1084+
fn = function([A, x], [res, res_2], trust_input=True)
10851085
assert any(isinstance(node.op, BandedDot) for node in fn.maker.fgraph.apply_nodes)
10861086

1087-
x_val, x2_val = fn(A_val, b_val)
1087+
out_val, out_2_val = fn(A_val, x_val)
10881088

1089-
np.testing.assert_allclose(x_val, x2_val)
1089+
np.testing.assert_allclose(out_val, out_2_val)
10901090

10911091

10921092
@pytest.mark.parametrize("op", ["dot", "banded_dot"], ids=str)
@@ -1099,17 +1099,17 @@ def test_banded_dot_perf(op, A_shape, benchmark):
10991099
rng = np.random.default_rng()
11001100

11011101
A_val = _make_banded_A(rng.normal(size=A_shape), kl=1, ku=1).astype(config.floatX)
1102-
b_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX)
1102+
x_val = rng.normal(size=(A_shape[-1],)).astype(config.floatX)
11031103

11041104
A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype)
1105-
b = pt.tensor("b", shape=b_val.shape, dtype=A_val.dtype)
1105+
x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype)
11061106

11071107
if op == "dot":
11081108
f = pt.dot
11091109
elif op == "banded_dot":
11101110
f = functools.partial(banded_dot, lower_diags=1, upper_diags=1)
11111111

1112-
res = f(A, b)
1113-
fn = function([A, b], res, trust_input=True)
1112+
res = f(A, x)
1113+
fn = function([A, x], res, trust_input=True)
11141114

1115-
benchmark(fn, A_val, b_val)
1115+
benchmark(fn, A_val, x_val)

0 commit comments

Comments
 (0)