Skip to content

Commit 497c974

Browse files
committed
Compose XDot and Sum
1 parent 23b1799 commit 497c974

File tree

2 files changed

+13
-17
lines changed

2 files changed

+13
-17
lines changed

pytensor/xtensor/math.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,12 @@ class XDot(XOp):
150150
----------
151151
dims : tuple of str
152152
The dimensions to contract over. If None, will contract over all matching dimensions.
153-
sum_result : bool
154-
If True, sum over all remaining axes after contraction (for full contraction, e.g. dims=...).
155153
"""
156154

157-
__props__ = ("dims", "sum_result")
155+
__props__ = ("dims",)
158156

159-
def __init__(self, dims: Iterable[str], sum_result: bool = False):
157+
def __init__(self, dims: Iterable[str]):
160158
self.dims = dims
161-
self.sum_result = sum_result
162159
super().__init__()
163160

164161
def make_node(self, x, y):
@@ -176,12 +173,8 @@ def make_node(self, x, y):
176173
]
177174

178175
# Combine remaining dimensions
179-
if self.sum_result:
180-
out_dims = ()
181-
out_shape = ()
182-
else:
183-
out_dims = tuple(x_dims + y_dims)
184-
out_shape = tuple(x_shape + y_shape)
176+
out_dims = tuple(x_dims + y_dims)
177+
out_shape = tuple(x_shape + y_shape)
185178

186179
# Determine output dtype
187180
out_dtype = upcast(x.type.dtype, y.type.dtype)
@@ -249,4 +242,12 @@ def dot(x, y, dims: str | Iterable[str] | EllipsisType | None = None):
249242
y_dims = set(y.type.dims)
250243
dims = tuple(x_dims & y_dims)
251244

252-
return XDot(dims=dims, sum_result=sum_result)(x, y)
245+
result = XDot(dims=dims)(x, y)
246+
247+
if sum_result:
248+
from pytensor.xtensor.reduction import sum as xtensor_sum
249+
250+
# Sum over all remaining axes
251+
result = xtensor_sum(result, dim=...)
252+
253+
return result

pytensor/xtensor/rewriting/math.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,5 @@ def lower_dot(fgraph, node):
3737
# Perform the tensordot operation
3838
out_tensor = tensordot(x_tensor, y_tensor, axes=(x_axes, y_axes))
3939

40-
# Sum over all remaining axes if needed
41-
if node.op.sum_result:
42-
# Sum over all remaining dimensions
43-
out_tensor = out_tensor.sum(axis=None)
44-
4540
# Convert back to xtensor
4641
return [xtensor_from_tensor(out_tensor, out.type.dims)]

0 commit comments

Comments
 (0)