We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c822a8e commit 11eaa32Copy full SHA for 11eaa32
pytensor/tensor/basic.py
@@ -4380,7 +4380,7 @@ def atleast_Nd(
4380
atleast_3d = partial(atleast_Nd, n=3)
4381
4382
4383
-def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
+def expand_dims(a: "TensorLike", axis: Sequence[int] | int) -> TensorVariable:
4384
"""Expand the shape of an array.
4385
4386
Insert a new axis that will appear at the `axis` position in the expanded
0 commit comments