@@ -150,15 +150,12 @@ class XDot(XOp):
150150 ----------
151151 dims : tuple of str
152152 The dimensions to contract over. If None, will contract over all matching dimensions.
153- sum_result : bool
154- If True, sum over all remaining axes after contraction (for full contraction, e.g. dims=...).
155153 """
156154
157- __props__ = ("dims" , "sum_result" )
155+ __props__ = ("dims" ,)
158156
159- def __init__ (self , dims : Iterable [str ], sum_result : bool = False ):
157+ def __init__ (self , dims : Iterable [str ]):
160158 self .dims = dims
161- self .sum_result = sum_result
162159 super ().__init__ ()
163160
164161 def make_node (self , x , y ):
@@ -176,12 +173,8 @@ def make_node(self, x, y):
176173 ]
177174
178175 # Combine remaining dimensions
179- if self .sum_result :
180- out_dims = ()
181- out_shape = ()
182- else :
183- out_dims = tuple (x_dims + y_dims )
184- out_shape = tuple (x_shape + y_shape )
176+ out_dims = tuple (x_dims + y_dims )
177+ out_shape = tuple (x_shape + y_shape )
185178
186179 # Determine output dtype
187180 out_dtype = upcast (x .type .dtype , y .type .dtype )
@@ -249,4 +242,12 @@ def dot(x, y, dims: str | Iterable[str] | EllipsisType | None = None):
249242 y_dims = set (y .type .dims )
250243 dims = tuple (x_dims & y_dims )
251244
252- return XDot (dims = dims , sum_result = sum_result )(x , y )
245+ result = XDot (dims = dims )(x , y )
246+
247+ if sum_result :
248+ from pytensor .xtensor .reduction import sum as xtensor_sum
249+
250+ # Sum over all remaining axes
251+ result = xtensor_sum (result , dim = ...)
252+
253+ return result
0 commit comments