Commit 80acf20

Rename XDot Op to Dot

1 parent e921915 · commit 80acf20
2 files changed: +4 -4 lines changed

pytensor/xtensor/math.py  (2 additions, 2 deletions)

@@ -145,7 +145,7 @@ def softmax(x, dim=None):
     return exp_x / exp_x.sum(dim=dim)
 
 
-class XDot(XOp):
+class Dot(XOp):
     """Matrix multiplication between two XTensorVariables.
 
     This operation performs matrix multiplication between two tensors, automatically
@@ -247,6 +247,6 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None):
         if d not in union:
             raise ValueError(f"Dimension {d} not found in either input")
 
-    result = XDot(dims=tuple(dim_set))(x, y)
+    result = Dot(dims=tuple(dim_set))(x, y)
 
     return result
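
For reference, a minimal usage sketch of the renamed `dot` helper. This is an illustration, not part of the commit: the `pytensor.xtensor.xtensor` constructor, its `dims`/`shape` keywords, and the dimension names are assumptions; only `pytensor.xtensor.math.dot` and its `dim` argument appear in the diff above.

# Hypothetical sketch; the xtensor constructor, names, and shapes are assumptions.
from pytensor.xtensor import xtensor
from pytensor.xtensor.math import dot

x = xtensor("x", dims=("a", "b"), shape=(2, 3))
y = xtensor("y", dims=("b", "c"), shape=(3, 4))

# Contract over the shared "b" dimension; the Op recorded in the graph is now Dot.
z = dot(x, y, dim="b")  # expected result dims: ("a", "c")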

pytensor/xtensor/rewriting/math.py  (2 additions, 2 deletions)

@@ -4,12 +4,12 @@
 from pytensor.tensor import einsum
 from pytensor.tensor.shape import specify_shape
 from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
-from pytensor.xtensor.math import XDot
+from pytensor.xtensor.math import Dot
 from pytensor.xtensor.rewriting.utils import register_lower_xtensor
 
 
 @register_lower_xtensor
-@node_rewriter(tracks=[XDot])
+@node_rewriter(tracks=[Dot])
 def lower_dot(fgraph, node):
     """Rewrite XDot to tensor.dot.
 
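
Conceptually, the rewrite registered here lowers the dims-based Dot to a plain contraction over the unlabeled tensors, using the einsum imported at the top of the file. A rough sketch of that equivalence for two 2-d operands sharing one dimension (subscripts and shapes are illustrative, not the rewrite's actual body):

# Illustrative mapping from a named-dimension contraction to einsum.
import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3))  # plays the role of dims ("a", "b")
y = pt.tensor("y", shape=(3, 4))  # plays the role of dims ("b", "c")

# Contracting over the shared "b" dimension corresponds to "ab,bc->ac".
z = pt.einsum("ab,bc->ac", x, y)  # shape (2, 4), i.e. dims ("a", "c")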
