Commit a6b729e

Make Dot only accept matrix inputs
1 parent d36d480 commit a6b729e
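Editor's illustration (not part of the commit) of the user-facing effect, assuming the public pytensor.tensor API and the module-level _dot instance defined in pytensor/tensor/math.py below: the core Dot Op now rejects vector inputs, while pt.dot keeps accepting them by promoting to matrices and squeezing, as the dense_dot change in this diff shows.

    # Hypothetical usage sketch; the error text comes from the new make_node below.
    import pytensor.tensor as pt
    from pytensor.tensor.math import _dot  # module-level Dot() instance from this diff

    x = pt.matrix("x")
    y = pt.matrix("y")
    v = pt.vector("v")

    _dot(x, y)    # fine: both inputs are 2D
    pt.dot(x, v)  # still fine: dense_dot promotes v to a column matrix and squeezes

    try:
        _dot(x, v)  # the core Op is now matrix-only
    except TypeError as err:
        print(err)  # "Dot Op expects a 2D tensor as input 1, ..."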

File tree

7 files changed: +91, -191 lines changed

7 files changed

+91
-191
lines changed

pytensor/tensor/math.py

Lines changed: 49 additions & 94 deletions
@@ -40,12 +40,13 @@
     get_normalized_batch_axes,
     scalar_elemwise,
 )
-from pytensor.tensor.shape import shape, specify_broadcastable
+from pytensor.tensor.shape import shape, specify_shape
 from pytensor.tensor.type import (
     DenseTensorType,
     complex_dtypes,
     continuous_dtypes,
     discrete_dtypes,
+    float_dtypes,
     int_dtypes,
     tensor,
     uint_dtypes,
@@ -2986,9 +2987,7 @@ def clip(x, min, max):
 
 class Dot(Op):
     """
-    Computes the dot product of two variables. For two matrices, this is
-    equivalent to matrix multiplication. For two vectors, this is the inner
-    product.
+    Computes the dot product of two matrix variables.
 
     Notes
     -----
@@ -3001,97 +3000,57 @@ class Dot(Op):
 
     """
 
+    gufunc_signature = "(m,n),(n,p)->(m,p)"
+    gufunc_spec = ("np.matmul", 2, 1)
     __props__ = ()
 
-    # the rationale for Dot22 is related to getting GEMM Ops into the
-    # graph. See Dot22 in tensor.blas for details.
-
-    def make_node(self, *inputs):
-        inputs = list(map(as_tensor_variable, inputs))
+    def make_node(self, x, y):
+        x = as_tensor_variable(x)
+        y = as_tensor_variable(y)
 
-        if len(inputs) != 2:
-            raise TypeError(f"Two arguments required, {len(inputs)} given ")
-        if inputs[0].ndim not in (1, 2):
+        if x.type.ndim != 2:
             raise TypeError(
-                "Input 0 (0-indexed) must have ndim of "
-                f"1 or 2, {int(inputs[0].ndim)} given. Consider calling "
-                "pytensor.tensor.dot instead."
+                f"Dot Op expects a 2D tensor as input 0, got {x} with {x.type.ndim} dimensions"
             )
-        if inputs[1].ndim not in (1, 2):
+        if y.type.ndim != 2:
             raise TypeError(
-                "Input 1 (0-indexed) must have ndim of "
-                f"1 or 2, {int(inputs[1].ndim)} given. Consider calling "
-                "pytensor.tensor.dot instead."
+                f"Dot Op expects a 2D tensor as input 1, got {y} with {y.type.ndim} dimensions"
             )
 
-        sx, sy = (input.type.shape for input in inputs)
+        sx, sy = x.type.shape, y.type.shape
         if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
             raise ValueError(
                 f"Incompatible shared dimension for dot product: {sx}, {sy}"
             )
+        sz = sx[:-1] + sy[-1:]
+        outputs = [tensor(dtype=ps.upcast(x.type.dtype, y.type.dtype), shape=sz)]
+        return Apply(self, [x, y], outputs)
 
-        if len(sy) == 2:
-            sz = sx[:-1] + sy[-1:]
-        elif len(sy) == 1:
-            sz = sx[:-1]
-
-        i_dtypes = [input.type.dtype for input in inputs]
-        outputs = [tensor(dtype=ps.upcast(*i_dtypes), shape=sz)]
-        return Apply(self, inputs, outputs)
-
-    def perform(self, node, inp, out):
-        x, y = inp
-        (z,) = out
-
-        # the asarray is here because dot between two vectors
-        # gives a numpy float object but we need to return a 0d
-        # ndarray
-        z[0] = np.asarray(np.dot(x, y))
+    def perform(self, node, inputs, output_storage):
+        output_storage[0][0] = np.matmul(*inputs)
 
     def grad(self, inp, grads):
         x, y = inp
         (gz,) = grads
-        xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
-
-        # grad is scalar, so x is vector and y is vector
-        if gdim == 0:
-            xgrad = gz * y
-            ygrad = gz * x
-
-        # x is vector, y is matrix, grad is vector
-        elif xdim == 1 and ydim == 2:
-            xgrad = dot(gz, y.T)
-            ygrad = outer(x.T, gz)
 
-        # x is matrix, y is vector, grad is vector
-        elif xdim == 2 and ydim == 1:
-            xgrad = outer(gz, y.T)
-            ygrad = dot(x.T, gz)
-
-        # x is matrix, y is matrix, grad is matrix
-        elif xdim == ydim == 2:
-            xgrad = dot(gz, y.T)
-            ygrad = dot(x.T, gz)
+        xgrad = self(gz, y.T)
+        ygrad = self(x.T, gz)
 
         # If x or y contain broadcastable dimensions but only one of
         # them know that a matching dimensions is broadcastable, the
         # above code don't always return the right broadcast pattern.
         # This cause problem down the road. See gh-1461.
-        if xgrad.broadcastable != x.broadcastable:
-            xgrad = specify_broadcastable(
-                xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
-            )
-        if ygrad.broadcastable != y.broadcastable:
-            ygrad = specify_broadcastable(
-                ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b)
-            )
+        if xgrad.type.shape != x.type.shape:
+            xgrad = specify_shape(xgrad, x.type.shape)
+        if ygrad.type.shape != y.type.shape:
+            ygrad = specify_shape(ygrad, y.type.shape)
 
-        rval = xgrad, ygrad
+        if xgrad.type.dtype not in float_dtypes:
+            raise TypeError("Dot grad x output must be a float type")
+        if ygrad.type.dtype not in float_dtypes:
+            raise TypeError("Dot grad y output must be a float type")
 
-        for elem in rval:
-            assert elem.dtype.find("float") != -1
-
-        return rval
+        return xgrad, ygrad
 
     def R_op(self, inputs, eval_points):
         # R_op for a \dot b evaluated at c for a and d for b is
@@ -3116,24 +3075,7 @@ def R_op(self, inputs, eval_points):
 
     def infer_shape(self, fgraph, node, shapes):
         xshp, yshp = shapes
-        x, y = node.inputs
-
-        # vector / vector
-        if x.ndim == 1 and y.ndim == 1:
-            return [()]
-        # matrix / vector
-        if x.ndim == 2 and y.ndim == 1:
-            return [xshp[:-1]]
-        # vector / matrix
-        if x.ndim == 1 and y.ndim == 2:
-            return [yshp[-1:]]
-        # matrix / matrix
-        if x.ndim == 2 and y.ndim == 2:
-            return [xshp[:-1] + yshp[-1:]]
-        raise NotImplementedError()
-
-    def __str__(self):
-        return "dot"
+        return [[xshp[0], yshp[1]]]
 
 
 _dot = Dot()
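Editor's note: with only the matrix-matrix case left, the gradient rule in the hunk above reduces to xgrad = gz @ y.T and ygrad = x.T @ gz. A NumPy-only finite-difference sketch of those identities (the names loss and num_grad are illustrative, not from the codebase):

    import numpy as np

    rng = np.random.default_rng(0)
    x, y = rng.normal(size=(3, 4)), rng.normal(size=(4, 5))
    gz = rng.normal(size=(3, 5))

    def loss(a, b):
        # scalar objective whose gradients inject gz as the upstream gradient of a @ b
        return np.sum(gz * (a @ b))

    def num_grad(f, a, eps=1e-6):
        # central finite differences, one entry at a time
        g = np.empty_like(a)
        it = np.nditer(a, flags=["multi_index"])
        for _ in it:
            da = np.zeros_like(a)
            da[it.multi_index] = eps
            g[it.multi_index] = (f(a + da) - f(a - da)) / (2 * eps)
        return g

    assert np.allclose(num_grad(lambda a: loss(a, y), x), gz @ y.T, atol=1e-5)  # xgrad = self(gz, y.T)
    assert np.allclose(num_grad(lambda b: loss(x, b), y), x.T @ gz, atol=1e-5)  # ygrad = self(x.T, gz)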
@@ -3215,7 +3157,24 @@ def dense_dot(a, b):
     elif a.ndim > 2 or b.ndim > 2:
         return tensordot(a, b, [[a.ndim - 1], [np.maximum(0, b.ndim - 2)]])
     else:
-        return _dot(a, b)
+        row_vector = a.ndim == 1
+        if row_vector:
+            # Promote to row matrix
+            a = a[None]
+
+        col_vector = b.ndim == 1
+        if col_vector:
+            # Promote to column matrix
+            b = b[:, None]
+
+        out = _dot(a, b)
+        if row_vector:
+            # If we promoted a to a row matrix, we need to squeeze the first dimension
+            out = out.squeeze(0)
+        if col_vector:
+            # If we promoted b to a column matrix, we need to squeeze the last dimension
+            out = out.squeeze(-1)
+        return out
 
 
 def tensordot(
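Editor's note: the promote-then-squeeze logic added to dense_dot reproduces np.dot semantics for the vector cases the core Op no longer accepts. A NumPy-only sketch of the equivalence:

    import numpy as np

    rng = np.random.default_rng(1)
    M = rng.normal(size=(3, 4))
    v, u = rng.normal(size=4), rng.normal(size=4)
    w = rng.normal(size=3)

    # matrix @ vector: promote v to a column matrix, then squeeze the last axis
    assert np.allclose((M @ v[:, None]).squeeze(-1), np.dot(M, v))
    # vector @ matrix: promote w to a row matrix, then squeeze the first axis
    assert np.allclose((w[None] @ M).squeeze(0), np.dot(w, M))
    # vector @ vector: promote both, squeeze both, giving a 0-d result
    assert np.allclose((v[None] @ u[:, None]).squeeze(), np.dot(v, u))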
@@ -3921,11 +3880,7 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))
 
 
-_matmul = Blockwise(
-    _dot,
-    signature="(m,k),(k,n)->(m,n)",
-    gufunc_spec=("numpy.matmul", 2, 1),
-)
+_matmul = Blockwise(_dot, name="matmul")
 
 
 def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
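Editor's note: because Dot now declares the core gufunc signature "(m,n),(n,p)->(m,p)" on itself, Blockwise no longer needs an explicit signature or gufunc_spec; it only has to broadcast leading batch dimensions over that core, which is the same contract np.matmul implements. A NumPy sketch of that contract:

    import numpy as np

    rng = np.random.default_rng(2)
    a = rng.normal(size=(7, 2, 3, 4))  # two batch dims, core shape (3, 4)
    b = rng.normal(size=(1, 2, 4, 5))  # batch dims broadcast against (7, 2), core shape (4, 5)

    out = np.matmul(a, b)
    assert out.shape == (7, 2, 3, 5)   # batch dims broadcast, core dims follow (m,n),(n,p)->(m,p)

    # the batched result is just the core 2D product applied per batch entry
    assert np.allclose(out[4, 1], a[4, 1] @ b[0, 1])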

pytensor/tensor/rewriting/blas.py

Lines changed: 1 addition & 17 deletions
@@ -107,7 +107,6 @@
 )
 from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
 from pytensor.tensor.type import (
-    DenseTensorType,
     TensorType,
     integer_dtypes,
     values_eq_approx_remove_inf_nan,
@@ -580,29 +579,14 @@ def print_profile(cls, stream, prof, level=0):
 def local_dot_to_dot22(fgraph, node):
     # This works for tensor.outer too because basic.outer is a macro that
     # produces a dot(dimshuffle,dimshuffle) of form 4 below
-    if not isinstance(node.op, Dot):
-        return
-
-    if any(not isinstance(i.type, DenseTensorType) for i in node.inputs):
-        return False
-
     x, y = node.inputs
     if y.type.dtype != x.type.dtype:
         # TODO: upcast one so the types match
         _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}")
         return
 
     if y.type.dtype in ("float16", "float32", "float64", "complex64", "complex128"):
-        if x.ndim == 2 and y.ndim == 2:
-            new_out = [_dot22(*node.inputs)]
-        elif x.ndim == 2 and y.ndim == 1:
-            new_out = [_dot22(x, y.dimshuffle(0, "x")).dimshuffle(0)]
-        elif x.ndim == 1 and y.ndim == 2:
-            new_out = [_dot22(x.dimshuffle("x", 0), y).dimshuffle(1)]
-        elif x.ndim == 1 and y.ndim == 1:
-            new_out = [_dot22(x.dimshuffle("x", 0), y.dimshuffle(0, "x")).dimshuffle()]
-        else:
-            return
+        new_out = [_dot22(*node.inputs)]
         copy_stack_trace(node.outputs, new_out)
         return new_out
 
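Editor's sketch of how one might check this rewrite end to end. Whether a Dot22/Gemm-family Op appears in the compiled graph depends on the compilation mode and the linked BLAS, so treat the expected output as an assumption rather than a guarantee:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.matrix("x")
    y = pt.matrix("y")
    f = pytensor.function([x, y], pt.dot(x, y))

    # with default FAST_RUN rewrites and a working BLAS, expect a Dot22/Gemm-family
    # Op here rather than the core Dot
    print([node.op for node in f.maker.fgraph.toposort()])

    xv = np.random.default_rng(3).normal(size=(3, 4)).astype(x.dtype)
    yv = np.random.default_rng(4).normal(size=(4, 5)).astype(y.dtype)
    assert np.allclose(f(xv, yv), xv @ yv)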

pytensor/tensor/rewriting/math.py

Lines changed: 25 additions & 50 deletions
@@ -19,7 +19,6 @@
     node_rewriter,
 )
 from pytensor.graph.rewriting.utils import get_clients_at_depth
-from pytensor.raise_op import assert_op
 from pytensor.tensor.basic import (
     Alloc,
     Join,
@@ -34,6 +33,7 @@
     ones_like,
     register_infer_shape,
     switch,
+    zeros,
     zeros_like,
 )
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
@@ -44,12 +44,10 @@
     Prod,
     Sum,
     _conj,
-    _dot,
     _matmul,
     add,
     digamma,
     dot,
-    eq,
     erf,
     erfc,
     exp,
@@ -130,16 +128,12 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
     return consts, origconsts, nonconsts
 
 
-@register_canonicalize
-@register_stabilize
+@register_canonicalize("shape_unsafe")
+@register_stabilize("shape_unsafe")
 @node_rewriter([Dot])
 def local_0_dot_x(fgraph, node):
-    if not isinstance(node.op, Dot):
-        return False
-
-    x = node.inputs[0]
-    y = node.inputs[1]
-    replace = (
+    x, y = node.inputs
+    if (
         get_underlying_scalar_constant_value(
             x, only_process_constants=True, raise_not_constant=False
         )
@@ -148,26 +142,12 @@ def local_0_dot_x(fgraph, node):
             y, only_process_constants=True, raise_not_constant=False
         )
         == 0
-    )
-
-    if replace:
-        constant_zero = constant(0, dtype=node.outputs[0].type.dtype)
-        if x.ndim == 2 and y.ndim == 2:
-            constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0]))
-            return [alloc(constant_zero, x.shape[0], y.shape[1])]
-        elif x.ndim == 1 and y.ndim == 2:
-            constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0]))
-            return [alloc(constant_zero, y.shape[1])]
-        elif x.ndim == 2 and y.ndim == 1:
-            constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0]))
-            return [alloc(constant_zero, x.shape[0])]
-        elif x.ndim == 1 and y.ndim == 1:
-            constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0]))
-            return [constant_zero]
+    ):
+        return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)]
 
 
 @register_canonicalize
-@node_rewriter([DimShuffle])
+@node_rewriter([Dot, _matmul])
 def local_lift_transpose_through_dot(fgraph, node):
     r"""Perform the rewrite ``dot(x,y).T -> dot(y.T, x.T)``.
 
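Editor's note: the simplified local_0_dot_x just allocates zeros of the output shape and, being registered as "shape_unsafe", no longer emits the assert_op/eq shape guard. The numerical fact it relies on, sketched with NumPy:

    import numpy as np

    x = np.zeros((3, 4))
    y = np.random.default_rng(5).normal(size=(4, 6))
    # a product with an all-zero operand is a zeros matrix of shape (x.shape[0], y.shape[1])
    assert np.array_equal(x @ y, np.zeros((x.shape[0], y.shape[1])))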
@@ -176,22 +156,24 @@ def local_lift_transpose_through_dot(fgraph, node):
     and to later merge consecutive `DimShuffle`\s.
     """
 
-    if not (
-        is_matrix_transpose(node.out)
-        and node.inputs[0].owner
-        and ((dot_op := node.inputs[0].owner.op) in (_dot, _matmul))
-    ):
-        return False
+    clients = fgraph.clients[node.out]
+    if len(clients) != 1:
+        # If the dot is used in more than one place, we don't want to duplicate it
+        return None
 
-    x, y = node.inputs[0].owner.inputs
+    [(client, _)] = clients
 
-    if x.ndim >= y.ndim >= 2:
-        # Output is dot product of transposed inputs in reverse order
-        ret = [dot_op(y.mT, x.mT)]
+    if not (isinstance(client.op, DimShuffle) and is_matrix_transpose(client.out)):
+        return None
 
-        # Copy over stack trace to output from result of dot-product
-        copy_stack_trace(node.inputs[0], ret)
-        return ret
+    x, y = node.inputs
+    # Output is dot product of transposed inputs in reverse order
+    ret = node.op(y.mT, x.mT)
+
+    # Copy over stack trace to output from result of dot-product
+    copy_stack_trace(node.out, ret)
+
+    return {client.out: ret}
 
 
 def _batched_matmul_to_core_matmul(fgraph, node, allow_reshape: bool):
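Editor's note: the rewrite rests on the transpose identity (x @ y).T == y.T @ x.T, applied per batch entry in the Blockwise/matmul case via .mT. A NumPy sketch, with np.swapaxes standing in for .mT (which needs NumPy 2.0):

    import numpy as np

    rng = np.random.default_rng(6)
    x, y = rng.normal(size=(3, 4)), rng.normal(size=(4, 5))
    assert np.allclose((x @ y).T, y.T @ x.T)

    # batched case: the same identity holds on the last two axes of every batch entry
    a, b = rng.normal(size=(2, 3, 4)), rng.normal(size=(2, 4, 5))
    assert np.allclose(
        np.swapaxes(a @ b, -1, -2),
        np.swapaxes(b, -1, -2) @ np.swapaxes(a, -1, -2),
    )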
@@ -344,21 +326,14 @@ def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node):
 
 @register_canonicalize
 @register_specialize
-@node_rewriter([_matmul, _dot])
+@node_rewriter([_matmul, Dot])
 def local_dot_to_mul(fgraph, node):
     """Rewrite blockwise dots that correspond to multiplication without summation."""
     a, b = node.inputs
     a_static_shape = a.type.shape
     b_static_shape = b.type.shape
 
-    if isinstance(node.op, Dot) and (
-        len(a_static_shape) != 2 or len(b_static_shape) != 2
-    ):
-        # For now, we only support matrix-matrix multiplication
-        # We should eventually canonicalize all dots to this form
-        return None
-
-    # Check if we have matrix matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
+    # Check if we have (..., m, 1) * (..., 1, n) -> (..., m, n)
     if not (a_static_shape[-1] == 1 or b_static_shape[-2] == 1):
         return None
 
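Editor's note: the (m, 1) @ (1, n) case this rewrite targets involves no summation, so it equals a broadcasted elementwise multiply. A NumPy sketch:

    import numpy as np

    rng = np.random.default_rng(7)
    a = rng.normal(size=(3, 1))
    b = rng.normal(size=(1, 5))
    # outer-product shape (3, 5) either way: matmul over a length-1 contraction
    # is the same as broadcasting and multiplying
    assert np.allclose(a @ b, a * b)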