Skip to content

Commit 646fc61

Browse files
committed
tensordot formatting
1 parent 8fe4205 commit 646fc61

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

array_api_compat/common/_aliases.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -509,13 +509,14 @@ def matrix_transpose(x: Array, /, xp: Namespace) -> Array:
509509
raise ValueError("x must be at least 2-dimensional for matrix_transpose")
510510
return xp.swapaxes(x, -1, -2)
511511

512-
def tensordot(x1: Array,
513-
x2: Array,
514-
/,
515-
xp: Namespace,
516-
*,
517-
axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
518-
**kwargs,
512+
def tensordot(
513+
x1: Array,
514+
x2: Array,
515+
/,
516+
xp: Namespace,
517+
*,
518+
axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
519+
**kwargs,
519520
) -> Array:
520521
return xp.tensordot(x1, x2, axes=axes, **kwargs)
521522

array_api_compat/torch/_aliases.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,14 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
712712
return _vecdot(x1, x2, axis=axis)
713713

714714
# torch.tensordot uses dims instead of axes
715-
def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> Array:
715+
def tensordot(
716+
x1: Array,
717+
x2: Array,
718+
/,
719+
*,
720+
axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
721+
**kwargs,
722+
) -> Array:
716723
# Note: torch.tensordot fails with integer dtypes when there is only 1
717724
# element in the axis (https://github.com/pytorch/pytorch/issues/84530).
718725
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

0 commit comments

Comments
 (0)