
Commit 2c7f980

.WIP
1 parent: cac3cd5

3 files changed (+40 −6 lines)

pytensor/tensor/rewriting/blas.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -916,6 +916,10 @@ def specialize_matmul_to_batched_dot(fgraph, node):
     """
     x, y = node.inputs
 
+    if x.type.ndim < 3:
+        # This doesn't actually have a batch dimension
+        return None
+
     # BatchedDot does not allow implicit broadcasting of the batch dimensions
     # We do not want to explicitly broadcast as it may result in huge arrays
     if x.type.broadcastable[:-2] != y.type.broadcastable[:-2]:
@@ -926,6 +930,7 @@ def specialize_matmul_to_batched_dot(fgraph, node):
     if len(x_shape) > 3:
         # If we have more than one batch dim, ravel it
         x = x.reshape((-1, x_shape[-2], x_shape[-1]))
+    if len(y_shape) > 3:
         y = y.reshape((-1, y_shape[-2], y_shape[-1]))
 
     new_out = _batched_dot(x, y)
```
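The ravel trick this rewrite relies on is easy to sanity-check outside PyTensor. A minimal NumPy sketch (shapes are illustrative, not from the commit): collapsing all leading batch dimensions into one, doing a single 3D batched product, and restoring the shape matches the fully broadcasted matmul.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((2, 3, 4, 5))  # two batch dims (2, 3), core shape (4, 5)
y = rng.random((2, 3, 5, 6))  # matching batch dims, as the rewrite requires

# Ravel each operand's batch dims independently, as the patched code does
x3 = x.reshape((-1, x.shape[-2], x.shape[-1]))  # (6, 4, 5)
y3 = y.reshape((-1, y.shape[-2], y.shape[-1]))  # (6, 5, 6)

# One 3D matmul is the batched product; a reshape restores the batch dims
out = np.matmul(x3, y3).reshape((*x.shape[:-2], x.shape[-2], y.shape[-1]))

assert np.allclose(out, np.matmul(x, y))
```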

pytensor/tensor/rewriting/math.py

Lines changed: 23 additions & 4 deletions
```diff
@@ -194,16 +194,15 @@ def local_lift_transpose_through_dot(fgraph, node):
     return ret
 
 
-@register_canonicalize
-@register_specialize
-@node_rewriter(tracks=[_matmul])
-def local_batched_matmul_to_core_matmul(fgraph, node):
+def _batched_matmul_to_core_matmul(fgraph, node, allow_reshape: bool):
     """Move batch dimensions of matmul operands to core matmul
 
     Example, if x has batch dimensions that don't overlap with batch dimensions of y
     x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
 
     It also works for batch dimensions of y that don't overlap with batch dimensions of x
+
+    The rewrite only needs reshape when mixing batch and core dimensions; if `allow_reshape=False` it refuses to apply in that case
     """
 
     x, y = node.inputs
@@ -247,6 +246,9 @@ def local_batched_matmul_to_core_matmul(fgraph, node):
         # x was a row matrix, squeeze it to clean up the graph
         x_stacked = x_stacked.squeeze(-2)
     if n_x_axis_to_merge > 1 or not x_is_row:
+        if not allow_reshape:
+            return None
+
         # Ravel moved batch dims together with (m) if needed
         x_stacked_shape = tuple(x_stacked.shape)
         x_stacked = x_stacked.reshape(
@@ -262,6 +264,8 @@ def local_batched_matmul_to_core_matmul(fgraph, node):
         # y was a column matrix, squeeze it to clean up the graph
         y_stacked = y_stacked.squeeze(-1)
     if n_y_axis_to_merge > 1 or not y_is_col:
+        if not allow_reshape:
+            return None
         # Ravel moved batch dims together with (n) if needed
         y_stacked_shape = tuple(y_stacked.shape)
         y_stacked = y_stacked.reshape(
@@ -319,6 +323,21 @@ def local_batched_matmul_to_core_matmul(fgraph, node):
     return [out]
 
 
+@register_canonicalize
+@node_rewriter(tracks=[_matmul])
+def local_batched_matmul_to_core_matmul(fgraph, node):
+    # Allow passing batch dimensions of matmul to core vector / column matrices
+    return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=False)
+
+
+@register_specialize
+@node_rewriter(tracks=[_matmul])
+def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node):
+    # Allow stacking batch dimensions of matmul with core dimensions, with a reshape operation
+    # We only apply this in specialize, because graphs with reshape are hard to work with
+    return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=True)
+
+
 @register_canonicalize
 @register_specialize
 @node_rewriter([_matmul])
```
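The transformation spelled out in the docstring can be verified with plain NumPy (a sketch assuming only `x` carries batch dimensions, the case the docstring describes; shapes are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((2, 3, 4, 5))  # batch dims (2, 3) belong to x alone
y = rng.random((5, 6))        # y is a core matrix

# x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
core = (x.reshape((-1, x.shape[-1])) @ y).reshape((*x.shape[:-1], y.shape[-1]))

assert np.allclose(core, x @ y)
```

Because the batch dimensions of `x` are folded into its row dimension, the batched matmul becomes a single core matrix product, which is the point of the rewrite.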

tests/tensor/rewriting/test_math.py

Lines changed: 12 additions & 2 deletions
```diff
@@ -4661,7 +4661,15 @@ def count_matvec_nodes(graph):
     assert count_matvec_nodes(out) == 1
 
     rewritten_out = rewrite_graph(
-        out, exclude=("local_eager_useless_unbatched_blockwise",)
+        out,
+        include=(
+            "canonicalize",
+            "specialize",
+        ),
+        exclude=(
+            "local_eager_useless_unbatched_blockwise",
+            "specialize_matmul_to_batched_dot",
+        ),
     )
     # No `matvec` in the rewritten out if one of the vector can be treated as a matrix
     expected = not any(
@@ -4675,7 +4683,9 @@ def count_matvec_nodes(graph):
         for vec_dim, mat_dim in zip(vec_shape[:-1], mat_shape[:-2])
     )
 
-    assert count_matvec_nodes(rewritten_out) == expected
+    assert count_matvec_nodes(rewritten_out) == expected, rewritten_out.dprint(
+        print_shape=True
+    )
 
     rng = np.random.default_rng(mat_shape + vec_shape)
     eval_dict = {mat: rng.random(mat.type.shape), vec: rng.random(vec.type.shape)}
```
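For context, the updated call pins `rewrite_graph` to the canonicalize and specialize phases and opts out of the two rewrites that would otherwise hide `matvec` nodes. A minimal sketch of the same pattern (tensor names and shapes are hypothetical; assumes `rewrite_graph` is importable from `pytensor.graph`, as used by this test module):

```python
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

mat = pt.tensor("mat", shape=(2, 3, 4))
vec = pt.tensor("vec", shape=(4,))
out = mat @ vec  # a batched matvec

rewritten = rewrite_graph(
    out,
    include=("canonicalize", "specialize"),
    exclude=(
        "local_eager_useless_unbatched_blockwise",
        "specialize_matmul_to_batched_dot",
    ),
)
rewritten.dprint(print_shape=True)  # inspect the resulting graph, with shapes
```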
