Skip to content

Commit 0ce2cae

Browse files
fix signature, add infer_shape
1 parent 1bcf463 commit 0ce2cae

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

pytensor/tensor/slinalg.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1704,7 +1704,7 @@ def _to_banded_form(A, kl, ku):
17041704

17051705
class BandedDot(Op):
17061706
__props__ = ("lower_diags", "upper_diags")
1707-
gufunc_signature = "(m,n),(n)->(n)"
1707+
gufunc_signature = "(m,n),(n)->(m)"
17081708

17091709
def __init__(self, lower_diags, upper_diags):
17101710
self.lower_diags = lower_diags
@@ -1719,6 +1719,9 @@ def make_node(self, A, b):
17191719

17201720
return pytensor.graph.basic.Apply(self, [A, B], [output])
17211721

1722+
def infer_shape(self, fgraph, nodes, shapes):
1723+
return [shapes[0][:-1]]
1724+
17221725
def perform(self, node, inputs, outputs_storage):
17231726
A, b = inputs
17241727
m, n = A.shape

0 commit comments

Comments
 (0)