We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 262d3aa commit bbd12a7Copy full SHA for bbd12a7
pytensor/tensor/math.py
@@ -3025,6 +3025,11 @@ def make_node(self, *inputs):
3025
)
3026
3027
sx, sy = (input.type.shape for input in inputs)
3028
+ if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
3029
+ raise ValueError(
3030
+ f"Incompatible shared dimension for dot product: {sx}, {sy}"
3031
+ )
3032
+
3033
if len(sy) == 2:
3034
sz = sx[:-1] + sy[-1:]
3035
elif len(sy) == 1:
0 commit comments