We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 89d5366 commit 7ffaae7Copy full SHA for 7ffaae7
pytensor/tensor/basic.py
@@ -4384,7 +4384,7 @@ def atleast_Nd(
4384
atleast_3d = partial(atleast_Nd, n=3)
4385
4386
4387
-def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
+def expand_dims(a: "TensorLike", axis: Sequence[int] | int) -> TensorVariable:
4388
"""Expand the shape of an array.
4389
4390
Insert a new axis that will appear at the `axis` position in the expanded
0 commit comments