Skip to content

Commit 8d30a29

Browse files
Add suggestions
1 parent 62ccf13 commit 8d30a29

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

pytensor/tensor/slinalg.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1679,6 +1679,11 @@ def __init__(self, lower_diags, upper_diags):
16791679
self.upper_diags = upper_diags
16801680

16811681
def make_node(self, A, x):
1682+
if A.ndim != 2:
1683+
raise TypeError("A must be a 2D tensor")
1684+
if x.ndim != 1:
1685+
raise TypeError("x must be a 1D tensor")
1686+
16821687
A = as_tensor_variable(A)
16831688
x = as_tensor_variable(x)
16841689

@@ -1688,7 +1693,8 @@ def make_node(self, A, x):
16881693
return pytensor.graph.basic.Apply(self, [A, x], [output])
16891694

16901695
def infer_shape(self, fgraph, nodes, shapes):
1691-
return [shapes[0][:-1]]
1696+
A_shape, _ = shapes
1697+
return [(A_shape[0],)]
16921698

16931699
def perform(self, node, inputs, outputs_storage):
16941700
A, x = inputs

0 commit comments

Comments
 (0)