Skip to content

Commit 11eaa32

Browse files
committed
Fix expand_dims type hint
1 parent c822a8e commit 11eaa32

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytensor/tensor/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4380,7 +4380,7 @@ def atleast_Nd(
43804380
atleast_3d = partial(atleast_Nd, n=3)
43814381

43824382

4383-
def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
4383+
def expand_dims(a: "TensorLike", axis: Sequence[int] | int) -> TensorVariable:
43844384
"""Expand the shape of an array.
43854385
43864386
Insert a new axis that will appear at the `axis` position in the expanded

0 commit comments

Comments
 (0)