77 from ._lib ._typing import Array , ModuleType
88
99from ._lib import _utils
10+ from ._lib ._compat import array_namespace
1011
1112__all__ = [
1213 "atleast_nd" ,
1920]
2021
2122
22- def atleast_nd (x : Array , / , * , ndim : int , xp : ModuleType ) -> Array :
23+ def atleast_nd (x : Array , / , * , ndim : int , xp : ModuleType | None = None ) -> Array :
2324 """
2425 Recursively expand the dimension of an array to at least `ndim`.
2526
@@ -28,8 +29,8 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
2829 x : array
2930 ndim : int
3031 The minimum number of dimensions for the result.
31- xp : array_namespace
32- The standard-compatible namespace for `x`.
32+ xp : array_namespace, optional
33+ The standard-compatible namespace for `x`. Default: infer
3334
3435 Returns
3536 -------
@@ -53,13 +54,16 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
5354 True
5455
5556 """
57+ if xp is None :
58+ xp = array_namespace (x )
59+
5660 if x .ndim < ndim :
5761 x = xp .expand_dims (x , axis = 0 )
5862 x = atleast_nd (x , ndim = ndim , xp = xp )
5963 return x
6064
6165
62- def cov (m : Array , / , * , xp : ModuleType ) -> Array :
66+ def cov (m : Array , / , * , xp : ModuleType | None = None ) -> Array :
6367 """
6468 Estimate a covariance matrix.
6569
@@ -77,8 +81,8 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
7781 A 1-D or 2-D array containing multiple variables and observations.
7882 Each row of `m` represents a variable, and each column a single
7983 observation of all those variables.
80- xp : array_namespace
81- The standard-compatible namespace for `m`.
84+ xp : array_namespace, optional
85+ The standard-compatible namespace for `m`. Default: infer
8286
8387 Returns
8488 -------
@@ -125,6 +129,9 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
125129 Array(2.14413333, dtype=array_api_strict.float64)
126130
127131 """
132+ if xp is None :
133+ xp = array_namespace (m )
134+
128135 m = xp .asarray (m , copy = True )
129136 dtype = (
130137 xp .float64 if xp .isdtype (m .dtype , "integral" ) else xp .result_type (m , xp .float64 )
@@ -150,7 +157,9 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
150157 return xp .squeeze (c , axis = axes )
151158
152159
153- def create_diagonal (x : Array , / , * , offset : int = 0 , xp : ModuleType ) -> Array :
160+ def create_diagonal (
161+ x : Array , / , * , offset : int = 0 , xp : ModuleType | None = None
162+ ) -> Array :
154163 """
155164 Construct a diagonal array.
156165
@@ -162,8 +171,8 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
162171 Offset from the leading diagonal (default is ``0``).
163172 Use positive ints for diagonals above the leading diagonal,
164173 and negative ints for diagonals below the leading diagonal.
165- xp : array_namespace
166- The standard-compatible namespace for `x`.
174+ xp : array_namespace, optional
175+ The standard-compatible namespace for `x`. Default: infer
167176
168177 Returns
169178 -------
@@ -189,6 +198,9 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
189198 [0, 0, 8, 0, 0]], dtype=array_api_strict.int64)
190199
191200 """
201+ if xp is None :
202+ xp = array_namespace (x )
203+
192204 if x .ndim != 1 :
193205 err_msg = "`x` must be 1-dimensional."
194206 raise ValueError (err_msg )
@@ -200,7 +212,7 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
200212
201213
202214def expand_dims (
203- a : Array , / , * , axis : int | tuple [int , ...] = (0 ,), xp : ModuleType
215+ a : Array , / , * , axis : int | tuple [int , ...] = (0 ,), xp : ModuleType | None = None
204216) -> Array :
205217 """
206218 Expand the shape of an array.
@@ -220,8 +232,8 @@ def expand_dims(
220232 given by a positive index could also be referred to by a negative index -
221233 that will also result in an error).
222234 Default: ``(0,)``.
223- xp : array_namespace
224- The standard-compatible namespace for `a`.
235+ xp : array_namespace, optional
236+ The standard-compatible namespace for `a`. Default: infer
225237
226238 Returns
227239 -------
@@ -265,6 +277,9 @@ def expand_dims(
265277 [2]]], dtype=array_api_strict.int64)
266278
267279 """
280+ if xp is None :
281+ xp = array_namespace (a )
282+
268283 if not isinstance (axis , tuple ):
269284 axis = (axis ,)
270285 ndim = a .ndim + len (axis )
@@ -282,7 +297,7 @@ def expand_dims(
282297 return a
283298
284299
285- def kron (a : Array , b : Array , / , * , xp : ModuleType ) -> Array :
300+ def kron (a : Array , b : Array , / , * , xp : ModuleType | None = None ) -> Array :
286301 """
287302 Kronecker product of two arrays.
288303
@@ -294,8 +309,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
294309 Parameters
295310 ----------
296311 a, b : array
297- xp : array_namespace
298- The standard-compatible namespace for `a` and `b`.
312+ xp : array_namespace, optional
313+ The standard-compatible namespace for `a` and `b`. Default: infer
299314
300315 Returns
301316 -------
@@ -357,6 +372,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
357372 Array(True, dtype=array_api_strict.bool)
358373
359374 """
375+ if xp is None :
376+ xp = array_namespace (a , b )
360377
361378 b = xp .asarray (b )
362379 singletons = (1 ,) * (b .ndim - a .ndim )
@@ -390,7 +407,12 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
390407
391408
392409def setdiff1d (
393- x1 : Array , x2 : Array , / , * , assume_unique : bool = False , xp : ModuleType
410+ x1 : Array ,
411+ x2 : Array ,
412+ / ,
413+ * ,
414+ assume_unique : bool = False ,
415+ xp : ModuleType | None = None ,
394416) -> Array :
395417 """
396418 Find the set difference of two arrays.
@@ -406,8 +428,8 @@ def setdiff1d(
406428 assume_unique : bool
407429 If ``True``, the input arrays are both assumed to be unique, which
408430 can speed up the calculation. Default is ``False``.
409- xp : array_namespace
410- The standard-compatible namespace for `x1` and `x2`.
431+ xp : array_namespace, optional
432+ The standard-compatible namespace for `x1` and `x2`. Default: infer
411433
412434 Returns
413435 -------
@@ -427,6 +449,8 @@ def setdiff1d(
427449 Array([1, 2], dtype=array_api_strict.int64)
428450
429451 """
452+ if xp is None :
453+ xp = array_namespace (x1 , x2 )
430454
431455 if assume_unique :
432456 x1 = xp .reshape (x1 , (- 1 ,))
@@ -436,7 +460,7 @@ def setdiff1d(
436460 return x1 [_utils .in1d (x1 , x2 , assume_unique = True , invert = True , xp = xp )]
437461
438462
439- def sinc (x : Array , / , * , xp : ModuleType ) -> Array :
463+ def sinc (x : Array , / , * , xp : ModuleType | None = None ) -> Array :
440464 r"""
441465 Return the normalized sinc function.
442466
@@ -456,8 +480,8 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
456480 x : array
457481 Array (possibly multi-dimensional) of values for which to calculate
458482 ``sinc(x)``. Must have a real floating point dtype.
459- xp : array_namespace
460- The standard-compatible namespace for `x`.
483+ xp : array_namespace, optional
484+ The standard-compatible namespace for `x`. Default: infer
461485
462486 Returns
463487 -------
@@ -511,6 +535,9 @@ def sinc(x: Array, /, *, xp: ModuleType) -> Array:
511535 -3.89817183e-17], dtype=array_api_strict.float64)
512536
513537 """
538+ if xp is None :
539+ xp = array_namespace (x )
540+
514541 if not xp .isdtype (x .dtype , "real floating" ):
515542 err_msg = "`x` must have a real floating data type."
516543 raise ValueError (err_msg )
0 commit comments