We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 42a7adb commit 9452257Copy full SHA for 9452257
pytensor/tensor/basic.py
@@ -4369,7 +4369,7 @@ def atleast_Nd(
4369
atleast_3d = partial(atleast_Nd, n=3)
4370
4371
4372
-def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
+def expand_dims(a: "TensorLike", axis: Sequence[int] | int) -> TensorVariable:
4373
"""Expand the shape of an array.
4374
4375
Insert a new axis that will appear at the `axis` position in the expanded
0 commit comments