Skip to content

Commit 979ed32

Browse files
Rewrite scalar dot as multiplication #1205
1 parent bf628c9 commit 979ed32

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,10 @@ def local_blockwise_dot_to_mul(fgraph, node):
295295
new_b = b
296296
else:
297297
return None
298+
299+
# new condition to handle (1,1) @ (1,1)
300+
if a.ndim == 2 and b.ndim == 2 and a.shape == (1, 1) and b.shape == (1, 1):
301+
return [a * b] # Direct elementwise multiplication
298302

299303
new_a = copy_stack_trace(a, new_a)
300304
new_b = copy_stack_trace(b, new_b)

0 commit comments

Comments
 (0)