We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent bf628c9 commit 979ed32Copy full SHA for 979ed32
pytensor/tensor/rewriting/math.py
@@ -295,6 +295,10 @@ def local_blockwise_dot_to_mul(fgraph, node):
295
new_b = b
296
else:
297
return None
298
+
299
+ # new condition to handle (1,1) @ (1,1)
300
+ if a.ndim == 2 and b.ndim == 2 and a.shape == (1, 1) and b.shape == (1, 1):
301
+ return [a * b] # Direct elementwise multiplication
302
303
new_a = copy_stack_trace(a, new_a)
304
new_b = copy_stack_trace(b, new_b)
0 commit comments