Skip to content

Commit e81657e

Browse files
committed
Generalizing XDot
1 parent 497c974 commit e81657e

File tree

3 files changed

+69
-58
lines changed

3 files changed

+69
-58
lines changed

pytensor/xtensor/math.py

Lines changed: 35 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -162,19 +162,15 @@ def make_node(self, x, y):
162162
x = as_xtensor(x)
163163
y = as_xtensor(y)
164164

165-
# Filter out contracted dimensions
166-
x_dims = [dim for dim in x.type.dims if dim not in self.dims]
167-
y_dims = [dim for dim in y.type.dims if dim not in self.dims]
168-
x_shape = [
169-
size for dim, size in zip(x.type.dims, x.type.shape) if dim not in self.dims
170-
]
171-
y_shape = [
172-
size for dim, size in zip(y.type.dims, y.type.shape) if dim not in self.dims
173-
]
174-
175-
# Combine remaining dimensions
176-
out_dims = tuple(x_dims + y_dims)
177-
out_shape = tuple(x_shape + y_shape)
165+
x_shape_dict = dict(zip(x.type.dims, x.type.shape))
166+
y_shape_dict = dict(zip(y.type.dims, y.type.shape))
167+
shape_dict = {**x_shape_dict, **y_shape_dict}
168+
169+
# Determine output dimensions
170+
out_dims = tuple(d for d in shape_dict if d not in self.dims)
171+
172+
# Determine output shape
173+
out_shape = tuple(shape_dict[d] for d in out_dims)
178174

179175
# Determine output dtype
180176
out_dtype = upcast(x.type.dtype, y.type.dtype)
@@ -183,7 +179,7 @@ def make_node(self, x, y):
183179
return Apply(self, [x, y], [out])
184180

185181

186-
def dot(x, y, dims: str | Iterable[str] | EllipsisType | None = None):
182+
def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
187183
"""Matrix multiplication between two XTensorVariables.
188184
189185
This operation performs matrix multiplication between two tensors, automatically
@@ -195,7 +191,7 @@ def dot(x, y, dims: str | Iterable[str] | EllipsisType | None = None):
195191
First input tensor
196192
y : XTensorVariable
197193
Second input tensor
198-
dims : str, Iterable[Hashable], EllipsisType, or None, optional
194+
dim : str, Iterable[Hashable], EllipsisType, or None, optional
199195
The dimensions to contract over. If None, will contract over all matching dimensions.
200196
If Ellipsis (...), will contract over all dimensions of both inputs; dimensions present in only one input are summed out.
201197
@@ -214,40 +210,34 @@ def dot(x, y, dims: str | Iterable[str] | EllipsisType | None = None):
214210
x = as_xtensor(x)
215211
y = as_xtensor(y)
216212

213+
x_dims = set(x.type.dims)
214+
y_dims = set(y.type.dims)
215+
intersection = x_dims & y_dims
216+
union = x_dims | y_dims
217+
217218
# Canonicalize dim into a set of dimension names
218-
if isinstance(dims, str):
219-
dims = (dims,)
220-
elif isinstance(dims, Iterable):
221-
dims = tuple(dims)
219+
if dim is None:
220+
dim_set = intersection
221+
elif dim is ...:
222+
dim_set = union
223+
elif isinstance(dim, str):
224+
dim_set = {dim}
225+
elif isinstance(dim, Iterable):
226+
dim_set = set(dim)
222227

223228
# Validate provided dims
224-
if isinstance(dims, Iterable):
225-
for dim in dims:
226-
if dim not in x.type.dims:
227-
raise ValueError(
228-
f"Dimension {dim} not found in first input {x.type.dims}"
229-
)
230-
if dim not in y.type.dims:
231-
raise ValueError(
232-
f"Dimension {dim} not found in second input {y.type.dims}"
233-
)
234-
235-
# If dims is ... , we have to sum over all remaining axes
236-
sum_result = dims is ...
237-
238-
# Handle None and ... cases
239-
if dims is None or dims is ...:
240-
# Contract over all matching dimensions
241-
x_dims = set(x.type.dims)
242-
y_dims = set(y.type.dims)
243-
dims = tuple(x_dims & y_dims)
244-
245-
result = XDot(dims=dims)(x, y)
246-
247-
if sum_result:
248-
from pytensor.xtensor.reduction import sum as xtensor_sum
229+
# Check if any dimension is not found in either input
230+
for d in dim_set:
231+
if d not in union:
232+
raise ValueError(f"Dimension {d} not found in either input {y.type.dims}")
233+
234+
dotted_dims = tuple(dim_set & intersection)
235+
summed_dims = tuple(dim_set.difference(dotted_dims))
236+
237+
result = XDot(dims=dotted_dims)(x, y)
249238

239+
if summed_dims:
250240
# Sum over the requested dims that are not shared by both inputs
251-
result = xtensor_sum(result, dim=...)
241+
result = result.sum(dim=summed_dims)
252242

253243
return result

pytensor/xtensor/type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -650,9 +650,9 @@ def stack(self, dim, **dims):
650650
def unstack(self, dim, **dims):
651651
return px.shape.unstack(self, dim, **dims)
652652

653-
def dot(self, other, dims=None):
653+
def dot(self, other, dim=None):
654654
"""Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
655-
return px.math.dot(self, other, dims=dims)
655+
return px.math.dot(self, other, dim=dim)
656656

657657

658658
class XTensorConstantSignature(tuple):

tests/xtensor/test_math.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def test_dot():
168168
xr_assert_allclose(z_test, expected)
169169

170170
# Test matrix-vector dot product with ellipsis
171-
z = x.dot(y, dims=...)
171+
z = x.dot(y, dim=...)
172172
fn = xr_function([x, y], z)
173173
z_test = fn(x_test, y_test)
174174
expected = x_test.dot(y_test, dim=...)
@@ -186,22 +186,22 @@ def test_dot():
186186
expected = x_test.dot(y_test)
187187
xr_assert_allclose(z_test, expected)
188188

189-
# Test matrix-matrix dot product with string dims
190-
z = x.dot(y, dims="b")
189+
# Test matrix-matrix dot product with string dim
190+
z = x.dot(y, dim="b")
191191
fn = xr_function([x, y], z)
192192
z_test = fn(x_test, y_test)
193193
expected = x_test.dot(y_test, dim="b")
194194
xr_assert_allclose(z_test, expected)
195195

196196
# Test matrix-matrix dot product with list of dims
197-
z = x.dot(y, dims=["b"])
197+
z = x.dot(y, dim=["b"])
198198
fn = xr_function([x, y], z)
199199
z_test = fn(x_test, y_test)
200200
expected = x_test.dot(y_test, dim=["b"])
201201
xr_assert_allclose(z_test, expected)
202202

203203
# Test matrix-matrix dot product with ellipsis
204-
z = x.dot(y, dims=...)
204+
z = x.dot(y, dim=...)
205205
fn = xr_function([x, y], z)
206206
z_test = fn(x_test, y_test)
207207
expected = x_test.dot(y_test, dim=...)
@@ -220,28 +220,49 @@ def test_dot():
220220
xr_assert_allclose(z_test, expected)
221221

222222
# Same but with explicit dimensions
223-
z = x.dot(y, dims=["b", "c"])
223+
z = x.dot(y, dim=["b", "c"])
224224
fn = xr_function([x, y], z)
225225
z_test = fn(x_test, y_test)
226226
expected = x_test.dot(y_test, dim=["b", "c"])
227227
xr_assert_allclose(z_test, expected)
228228

229229
# Same but with ellipses
230-
z = x.dot(y, dims=...)
230+
z = x.dot(y, dim=...)
231231
fn = xr_function([x, y], z)
232232

233233
z_test = fn(x_test, y_test)
234234
expected = x_test.dot(y_test, dim=...)
235235
xr_assert_allclose(z_test, expected)
236236

237+
# Dot product with sum
238+
x_test = DataArray(np.arange(24.0).reshape(2, 3, 4), dims=("a", "b", "c"))
239+
y_test = DataArray(np.arange(60.0).reshape(3, 4, 5), dims=("b", "c", "d"))
240+
expected = x_test.dot(y_test, dim=("a", "b", "c"))
241+
242+
x = xtensor("x", dims=("a", "b", "c"), shape=(2, 3, 4))
243+
y = xtensor("y", dims=("b", "c", "d"), shape=(3, 4, 5))
244+
z = x.dot(y, dim=("a", "b", "c"))
245+
fn = xr_function([x, y], z)
246+
z_test = fn(x_test, y_test)
247+
xr_assert_allclose(z_test, expected)
248+
249+
return
250+
# Dot product with sum in the middle
251+
# This is not supported yet; the unconditional `return` above keeps this case from running
252+
x_test = DataArray(np.arange(120).reshape(2, 3, 4, 5), dims=("a", "b", "c", "d"))
253+
y_test = DataArray(np.arange(360).reshape(3, 4, 5, 6), dims=("b", "c", "d", "e"))
254+
expected = x_test.dot(y_test, dim=("b", "d"))
255+
x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 4, 5))
256+
y = xtensor("y", dims=("b", "c", "d", "e"), shape=(3, 4, 5, 6))
257+
z = x.dot(y, dim=("b", "d"))
258+
fn = xr_function([x, y], z)
259+
z_test = fn(x_test, y_test)
260+
xr_assert_allclose(z_test, expected)
261+
237262

238263
def test_dot_errors():
239264
x = xtensor("x", dims=("a", "b"), shape=(2, 3))
240265
y = xtensor("y", dims=("b", "c"), shape=(3, 4))
241-
with pytest.raises(ValueError, match="Dimension c not found in first input"):
242-
x.dot(y, dims=["c"])
243-
with pytest.raises(ValueError, match="Dimension a not found in second input"):
244-
x.dot(y, dims=["a"])
245266

246267
# Test a case where there are no matching dimensions
247268
x_test = DataArray(np.ones((2, 3)), dims=("a", "b"))

0 commit comments

Comments
 (0)