Commit 7b44445

Implement infer_shape method in Op class and remove redundant implementations in subclasses
1 parent c161452

3 files changed (+7, -32 lines)

pytensor/graph/op.py

Lines changed: 7 additions & 0 deletions
@@ -20,6 +20,7 @@
     add_tag_trace,
     get_variable_trace_string,
 )
+from pytensor.tensor.utils import _gufunc_to_out_shape


 if TYPE_CHECKING:

@@ -596,6 +597,12 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
         # By default, do nothing
         return self

+    def infer_shape(self, fgraph, node, input_shapes):
+        if hasattr(self, "gufunc_signature"):
+            return _gufunc_to_out_shape(self.gufunc_signature, input_shapes)
+        else:
+            raise NotImplementedError(f"Op {self} does not implement infer_shape")
+
     def __str__(self):
         return getattr(type(self), "__name__", super().__str__())
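
Note on the new default: an Op that declares a gufunc_signature now inherits working shape inference, while an Op without one raises NotImplementedError from the base class. As a rough illustration of what a signature-to-shape helper such as _gufunc_to_out_shape has to do, here is a hypothetical, simplified sketch; the name signature_to_out_shape and its parsing details are invented for this example and are not the library implementation.

import re


def signature_to_out_shape(signature, input_shapes):
    # Bind each named core dimension (m, n, p, ...) to the size it takes in
    # the inputs, then assemble the output shapes from those bindings.
    in_spec, out_spec = signature.split("->")

    def parse(spec):
        return [
            tuple(dim for dim in group.split(",") if dim)
            for group in re.findall(r"\((.*?)\)", spec)
        ]

    dim_sizes = {}
    for dims, shape in zip(parse(in_spec), input_shapes):
        for name, size in zip(dims, shape):
            dim_sizes.setdefault(name, size)
    return [tuple(dim_sizes[name] for name in dims) for dims in parse(out_spec)]


# A matmul-style signature: (m, n) @ (n, p) -> (m, p)
print(signature_to_out_shape("(m,n),(n,p)->(m,p)", [(2, 3), (3, 4)]))  # [(2, 4)]

In the real code the shapes passed to infer_shape are symbolic rather than plain integers, but the bookkeeping is the same idea: bind each named core dimension from the inputs, then read the output shapes off the signature.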

pytensor/tensor/nlinalg.py

Lines changed: 0 additions & 16 deletions
@@ -17,7 +17,6 @@
 from pytensor.tensor.basic import as_tensor_variable, diagonal
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.type import Variable, dvector, lscalar, matrix, scalar, vector
-from pytensor.tensor.utils import _gufunc_to_out_shape


 class MatrixPinv(Op):

@@ -63,9 +62,6 @@ def L_op(self, inputs, outputs, g_outputs):
         ).T
         return [grad]

-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-

 def pinv(x, hermitian=False):
     """Computes the pseudo-inverse of a matrix :math:`A`.

@@ -156,9 +152,6 @@ def R_op(self, inputs, eval_points):
             return [None]
         return [-matrix_dot(xi, ev, xi)]

-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-

 inv = matrix_inverse = Blockwise(MatrixInverse())

@@ -225,9 +218,6 @@ def grad(self, inputs, g_outputs):
         (x,) = inputs
         return [gz * self(x) * matrix_inverse(x).T]

-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
     def __str__(self):
         return "Det"

@@ -259,9 +249,6 @@ def perform(self, node, inputs, outputs):
         except Exception as e:
             raise ValueError("Failed to compute determinant", x) from e

-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
     def __str__(self):
         return "SLogDet"

@@ -317,9 +304,6 @@ def perform(self, node, inputs, outputs):
         (w, v) = outputs
         w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x))

-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-

 eig = Blockwise(Eig())
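
All of the deletions above are the same boilerplate one-liner: each of these Ops (MatrixPinv, MatrixInverse, Det, SLogDet, Eig) already declares a gufunc_signature, so the inherited default produces the same shapes. A minimal sketch of a user-defined Op relying on that default follows; MatMulOp is invented for illustration and assumes the usual make_node/perform API.

import numpy as np

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import as_tensor_variable, matrix


class MatMulOp(Op):
    # Declaring the signature is now enough for shape inference: the inherited
    # Op.infer_shape default maps it onto the input shapes.
    gufunc_signature = "(m,n),(n,p)->(m,p)"

    def make_node(self, a, b):
        a, b = as_tensor_variable(a), as_tensor_variable(b)
        return Apply(self, [a, b], [matrix(dtype=a.dtype)])

    def perform(self, node, inputs, output_storage):
        a, b = inputs
        output_storage[0][0] = np.matmul(a, b)

The same signature is what Blockwise uses to batch a core Op (note the Blockwise(...) wrappers in this file), which is why these linalg Ops already carry one.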

pytensor/tensor/slinalg.py

Lines changed: 0 additions & 16 deletions
@@ -20,7 +20,6 @@
 from pytensor.tensor.nlinalg import kron, matrix_dot
 from pytensor.tensor.shape import reshape
 from pytensor.tensor.type import matrix, tensor, vector
-from pytensor.tensor.utils import _gufunc_to_out_shape
 from pytensor.tensor.variable import TensorVariable


@@ -51,9 +50,6 @@ def __init__(
         if self.overwrite_a:
             self.destroy_map = {0: [0]}

-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
     def make_node(self, x):
         x = as_tensor_variable(x)
         if x.type.ndim != 2:

@@ -269,9 +265,6 @@ def make_node(self, A, b):
         x = tensor(dtype=o_dtype, shape=b.type.shape)
         return Apply(self, [A, b], [x])

-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
     def L_op(self, inputs, outputs, output_gradients):
         r"""Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.

@@ -885,9 +878,6 @@ def perform(self, node, inputs, output_storage):
         out_dtype = node.outputs[0].type.dtype
         X[0] = scipy_linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)

-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
         # Note that they write the equation as AX + XA.H + Q = 0, while scipy uses AX + XA^H = Q,

@@ -957,9 +947,6 @@ def perform(self, node, inputs, output_storage):
             out_dtype
         )

-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
         A, Q = inputs

@@ -1077,9 +1064,6 @@ def perform(self, node, inputs, output_storage):
         out_dtype = node.outputs[0].type.dtype
         X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)

-    def infer_shape(self, fgraph, node, shapes):
-        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
-
     def grad(self, inputs, output_grads):
         # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
         A, B, Q, R = inputs
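
For an Op that declares no gufunc_signature, the inherited default raises NotImplementedError rather than guessing a shape. A small sketch of that branch, using a deliberately bare, hypothetical Op:

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import as_tensor_variable, vector


class PlainOp(Op):
    # Hypothetical Op with no gufunc_signature: the inherited infer_shape
    # falls through to the NotImplementedError branch.
    def make_node(self, x):
        x = as_tensor_variable(x)
        return Apply(self, [x], [vector(dtype=x.dtype)])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = inputs[0]


try:
    PlainOp().infer_shape(None, None, [(3,)])
except NotImplementedError as err:
    print(err)  # Op PlainOp does not implement infer_shape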

0 commit comments
