Skip to content

Commit e9d38b1

Browse files
committed
TYP: annotate kron and expand_dims
1 parent e9077f5 commit e9d38b1

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

src/array_api_extra/_funcs.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
__all__ = ["atleast_nd"]
99

1010

11-
def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
11+
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
1212
"""
1313
Recursively expand the dimension of an array to at least `ndim`.
1414
@@ -48,21 +48,24 @@ def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
4848
return x
4949

5050

51-
def expand_dims(a: Array, *, axis: tuple[int] = (0,), xp: ModuleType):
51+
def expand_dims(
52+
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType
53+
) -> Array:
5254
"""
5355
Expand the shape of an array.
5456
55-
Insert a new axis that will appear at the `axis` position in the expanded
56-
array shape.
57+
Insert (a) new axis/axes that will appear at the position(s) specified by
58+
`axis` in the expanded array shape.
5759
58-
This is ``xp.expand_dims`` for ``axis`` an int *or a tuple of ints*.
60+
This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*.
5961
Equivalent to ``numpy.expand_dims`` for NumPy arrays.
6062
6163
Parameters
6264
----------
6365
a : array
6466
axis : int or tuple of ints
6567
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
68+
Default: ``(0,)``.
6669
xp : array_namespace
6770
The standard-compatible namespace for `a`.
6871
@@ -120,7 +123,7 @@ def expand_dims(a: Array, *, axis: tuple[int] = (0,), xp: ModuleType):
120123
return a
121124

122125

123-
def kron(a: Array, b: Array, *, xp: ModuleType):
126+
def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
124127
"""
125128
Kronecker product of two arrays.
126129

0 commit comments

Comments
 (0)