Skip to content

Commit 7ffaae7

Browse files
committed
Fix expand_dims type hint
1 parent 89d5366 commit 7ffaae7

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytensor/tensor/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4384,7 +4384,7 @@ def atleast_Nd(
43844384
atleast_3d = partial(atleast_Nd, n=3)
43854385

43864386

4387-
def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
4387+
def expand_dims(a: "TensorLike", axis: Sequence[int] | int) -> TensorVariable:
43884388
"""Expand the shape of an array.
43894389
43904390
Insert a new axis that will appear at the `axis` position in the expanded

0 commit comments

Comments
 (0)