Skip to content

Commit a5f6ce4

Browse files
committed
Raise ShapeError instead of NotImplementedError for unimplemented infer_shape in Op class
1 parent 707c82e commit a5f6ce4

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pytensor/graph/op.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,9 @@ def infer_shape(self, fgraph, node, input_shapes):
602602

603603
return _gufunc_to_out_shape(self.gufunc_signature, input_shapes)
604604
else:
605-
raise NotImplementedError(f"Op {self} does not implement infer_shape")
605+
from pytensor.tensor.exceptions import ShapeError
606+
607+
raise ShapeError(f"Op {self} does not implement infer_shape")
606608

607609
def __str__(self):
608610
return getattr(type(self), "__name__", super().__str__())

0 commit comments

Comments
 (0)