Skip to content

Commit 667a8cb

Browse files
committed
Use autodoc for helper functions
This still requires some work to fix cross-linking.
1 parent d557948 commit 667a8cb

File tree

5 files changed

+274
-72
lines changed

5 files changed

+274
-72
lines changed

array_api_compat/common/_helpers.py

Lines changed: 207 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,24 @@
1919
import warnings
2020

2121
def is_numpy_array(x):
22+
"""
23+
Return True if `x` is a NumPy array.
24+
25+
This function does not import NumPy if it has not already been imported
26+
and is therefore cheap to use.
27+
28+
This also returns True for `ndarray` subclasses and NumPy scalar objects.
29+
30+
See Also
31+
--------
32+
33+
array_namespace
34+
is_array_api_obj
35+
is_cupy_array
36+
is_torch_array
37+
is_dask_array
38+
is_jax_array
39+
"""
2240
# Avoid importing NumPy if it isn't already
2341
if 'numpy' not in sys.modules:
2442
return False
@@ -29,6 +47,24 @@ def is_numpy_array(x):
2947
return isinstance(x, (np.ndarray, np.generic))
3048

3149
def is_cupy_array(x):
50+
"""
51+
Return True if `x` is a CuPy array.
52+
53+
This function does not import CuPy if it has not already been imported
54+
and is therefore cheap to use.
55+
56+
This also returns True for `cupy.ndarray` subclasses and CuPy scalar objects.
57+
58+
See Also
59+
--------
60+
61+
array_namespace
62+
is_array_api_obj
63+
is_numpy_array
64+
is_torch_array
65+
is_dask_array
66+
is_jax_array
67+
"""
3268
# Avoid importing NumPy if it isn't already
3369
if 'cupy' not in sys.modules:
3470
return False
@@ -39,6 +75,22 @@ def is_cupy_array(x):
3975
return isinstance(x, (cp.ndarray, cp.generic))
4076

4177
def is_torch_array(x):
78+
"""
79+
Return True if `x` is a PyTorch tensor.
80+
81+
This function does not import PyTorch if it has not already been imported
82+
and is therefore cheap to use.
83+
84+
See Also
85+
--------
86+
87+
array_namespace
88+
is_array_api_obj
89+
is_numpy_array
90+
is_cupy_array
91+
is_dask_array
92+
is_jax_array
93+
"""
4294
# Avoid importing torch if it isn't already
4395
if 'torch' not in sys.modules:
4496
return False
@@ -49,6 +101,22 @@ def is_torch_array(x):
49101
return isinstance(x, torch.Tensor)
50102

51103
def is_dask_array(x):
104+
"""
105+
Return True if `x` is a dask.array Array.
106+
107+
This function does not import dask if it has not already been imported
108+
and is therefore cheap to use.
109+
110+
See Also
111+
--------
112+
113+
array_namespace
114+
is_array_api_obj
115+
is_numpy_array
116+
is_cupy_array
117+
is_torch_array
118+
is_jax_array
119+
"""
52120
# Avoid importing dask if it isn't already
53121
if 'dask.array' not in sys.modules:
54122
return False
@@ -58,6 +126,23 @@ def is_dask_array(x):
58126
return isinstance(x, dask.array.Array)
59127

60128
def is_jax_array(x):
129+
"""
130+
Return True if `x` is a JAX array.
131+
132+
This function does not import JAX if it has not already been imported
133+
and is therefore cheap to use.
134+
135+
136+
See Also
137+
--------
138+
139+
array_namespace
140+
is_array_api_obj
141+
is_numpy_array
142+
is_cupy_array
143+
is_torch_array
144+
is_dask_array
145+
"""
61146
# Avoid importing jax if it isn't already
62147
if 'jax' not in sys.modules:
63148
return False
@@ -68,7 +153,17 @@ def is_jax_array(x):
68153

69154
def is_array_api_obj(x):
70155
"""
71-
Check if x is an array API compatible array object.
156+
Return True if `x` is an array API compatible array object.
157+
158+
See Also
159+
--------
160+
161+
array_namespace
162+
is_numpy_array
163+
is_cupy_array
164+
is_torch_array
165+
is_dask_array
166+
is_jax_array
72167
"""
73168
return is_numpy_array(x) \
74169
or is_cupy_array(x) \
@@ -87,17 +182,57 @@ def array_namespace(*xs, api_version=None, _use_compat=True):
87182
"""
88183
Get the array API compatible namespace for the arrays `xs`.
89184
90-
`xs` should contain one or more arrays.
185+
Parameters
186+
----------
187+
xs: arrays
188+
one or more arrays.
189+
190+
api_version: str
191+
The newest version of the spec that you need support for (currently
192+
the compat library wrapped APIs support v2022.12).
193+
194+
Returns
195+
-------
196+
197+
out: namespace
198+
The array API compatible namespace corresponding to the arrays in `xs`.
199+
200+
Raises
201+
------
202+
TypeError
203+
If `xs` contains arrays from different array libraries or contains a
204+
non-array.
205+
91206
92-
Typical usage is
207+
Typical usage is to pass the arguments of a function to
208+
`array_namespace()` at the top of a function to get the corresponding
209+
array API namespace:
93210
94-
def your_function(x, y):
95-
xp = array_api_compat.array_namespace(x, y)
96-
# Now use xp as the array library namespace
97-
return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
211+
.. code:: python
212+
213+
def your_function(x, y):
214+
xp = array_api_compat.array_namespace(x, y)
215+
# Now use xp as the array library namespace
216+
return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
217+
218+
219+
Wrapped array namespaces can also be imported directly. For example,
220+
`array_namespace(np.array(...))` will return `array_api_compat.numpy`.
221+
This function will also work for any array library not wrapped by
222+
array-api-compat if it explicitly defines `__array_namespace__
223+
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__array_namespace__.html>`__
224+
(the wrapped namespace is always preferred if it exists).
225+
226+
See Also
227+
--------
228+
229+
is_array_api_obj
230+
is_numpy_array
231+
is_cupy_array
232+
is_torch_array
233+
is_dask_array
234+
is_jax_array
98235
99-
api_version should be the newest version of the spec that you need support
100-
for (currently the compat library wrapped APIs only support v2021.12).
101236
"""
102237
namespaces = set()
103238
for x in xs:
@@ -181,15 +316,33 @@ def device(x: Array, /) -> Device:
181316
"""
182317
Hardware device the array data resides on.
183318
319+
This is equivalent to `x.device` according to the `standard
320+
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.device.html>`__.
321+
This helper is included because some array libraries either do not have
322+
the `device` attribute or include it with an incompatible API.
323+
184324
Parameters
185325
----------
186326
x: array
187-
array instance from NumPy or an array API compatible library.
327+
array instance from an array API compatible library.
188328
189329
Returns
190330
-------
191331
out: device
192-
a ``device`` object (see the "Device Support" section of the array API specification).
332+
a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
333+
section of the array API specification).
334+
335+
Notes
336+
-----
337+
338+
For NumPy the device is always `"cpu"`. For Dask, the device is always a
339+
special `DASK_DEVICE` object.
340+
341+
See Also
342+
--------
343+
344+
to_device : Move array data to a different device.
345+
193346
"""
194347
if is_numpy_array(x):
195348
return "cpu"
@@ -262,22 +415,52 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
262415
"""
263416
Copy the array from the device on which it currently resides to the specified ``device``.
264417
418+
This is equivalent to `x.to_device(device, stream=stream)` according to
419+
the `standard
420+
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.to_device.html>`__.
421+
This helper is included because some array libraries do not have the
422+
`to_device` method.
423+
265424
Parameters
266425
----------
426+
267427
x: array
268-
array instance from NumPy or an array API compatible library.
428+
array instance from an array API compatible library.
429+
269430
device: device
270-
a ``device`` object (see the "Device Support" section of the array API specification).
431+
a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
432+
section of the array API specification).
433+
271434
stream: Optional[Union[int, Any]]
272-
stream object to use during copy. In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using such an object would not be portable.
435+
stream object to use during copy. In addition to the types supported
436+
in ``array.__dlpack__``, implementations may choose to support any
437+
library-specific stream object with the caveat that any code using
438+
such an object would not be portable.
273439
274440
Returns
275441
-------
442+
276443
out: array
277-
an array with the same data and data type as ``x`` and located on the specified ``device``.
444+
an array with the same data and data type as ``x`` and located on the
445+
specified ``device``.
446+
447+
Notes
448+
-----
449+
450+
For NumPy, this function effectively does nothing since the only supported
451+
device is the CPU. For CuPy, this method supports CuPy CUDA `Device
452+
<https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Device.html>`_
453+
and `Stream
454+
<https://docs.cupy.dev/en/stable/reference/generated/cupy.cuda.Stream.html>`_
455+
objects. For PyTorch, this is the same as ``x.to(device)
456+
<https://pytorch.org/docs/stable/generated/torch.Tensor.to.html>`_ (the
457+
``stream`` argument is not supported in PyTorch).
458+
459+
See Also
460+
--------
461+
462+
device : Hardware device the array data resides on.
278463
279-
.. note::
280-
If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation.
281464
"""
282465
if is_numpy_array(x):
283466
if stream is not None:
@@ -305,7 +488,13 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
305488

306489
def size(x):
307490
"""
308-
Return the total number of elements of x
491+
Return the total number of elements of x.
492+
493+
This is equivalent to `x.size` according to the `standard
494+
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
495+
This helper is included because PyTorch defines `size` in an `incompatible
496+
way <https://pytorch.org/docs/stable/generated/torch.Tensor.size.html>`__.
497+
309498
"""
310499
if None in x.shape:
311500
return None

docs/conf.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
extensions = [
2424
'myst_parser',
25+
'sphinx.ext.autodoc',
26+
'sphinx.ext.napoleon',
2527
'sphinx_copybutton',
2628
]
2729

@@ -30,9 +32,15 @@
3032

3133
myst_enable_extensions = ["dollarmath", "linkify"]
3234

35+
napoleon_use_rtype = False
36+
napoleon_use_param = False
37+
3338
# Make sphinx give errors for bad cross-references
3439
nitpicky = True
3540

41+
# Lets us use single backticks for code in RST
42+
default_role = 'code'
43+
3644
# -- Options for HTML output -------------------------------------------------
3745
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
3846

docs/helper-functions.md

Lines changed: 0 additions & 46 deletions
This file was deleted.

0 commit comments

Comments
 (0)