Skip to content

Commit 3442f11

Browse files
committed
Define precise overloads for Array.reshape
1 parent 2ecf373 commit 3442f11

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

pytato/array.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@
199199
TYPE_CHECKING,
200200
Any,
201201
ClassVar,
202+
Literal,
202203
Protocol,
203204
TypeAlias,
204205
TypeVar,
@@ -246,6 +247,8 @@
246247
ShapeType: TypeAlias = tuple[ShapeComponent, ...]
247248
ConvertibleToShape: TypeAlias = "ShapeComponent | Sequence[ShapeComponent]"
248249

250+
OrderCF: TypeAlias = Literal["C"] | Literal["F"]
251+
249252
# }}}
250253

251254

@@ -971,7 +974,13 @@ def astype(self, dtype: DTypeLike) -> Array:
971974
dtype=dtype,
972975
)
973976

974-
def reshape(self, *shape: int | Sequence[int], order: str = "C") -> Array:
977+
@overload
978+
def reshape(self, *shape: int, order: OrderCF = "C") -> Array: ...
979+
980+
@overload
981+
def reshape(self, shape: tuple[int, ...], /, *, order: OrderCF = "C") -> Array: ...
982+
983+
def reshape(self, *shape: int | Sequence[int], order: OrderCF = "C") -> Array:
975984
import pytato as pt
976985
if len(shape) == 0:
977986
raise TypeError("reshape takes at least one argument (0 given)")

0 commit comments

Comments
 (0)