Skip to content

Commit 90de28c

Browse files
Merge pull request jax-ml#25335 from gnecula:export_doc_call
PiperOrigin-RevId: 704589764
2 parents 12c3057 + cc73c50 commit 90de28c

File tree

3 files changed

+54
-35
lines changed

3 files changed

+54
-35
lines changed

docs/developer.md

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -689,22 +689,21 @@ minimization phase.
689689
### Doctests
690690

691691
JAX uses pytest in doctest mode to test the code examples within the documentation.
692-
You can run this using
692+
You can find the up-to-date command to run doctests in
693+
[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml).
694+
E.g., you can run:
693695

694696
```
695-
pytest docs
697+
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
696698
```
697699

698700
Additionally, JAX runs pytest in `doctest-modules` mode to ensure code examples in
699701
function docstrings will run correctly. You can run this locally using, for example:
700702

701703
```
702-
pytest --doctest-modules jax/_src/numpy/lax_numpy.py
704+
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest --doctest-modules jax/_src/numpy/lax_numpy.py
703705
```
704706

705-
Keep in mind that there are several files that are marked to be skipped when the
706-
doctest command is run on the full package; you can see the details in
707-
[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml)
708707

709708
## Type checking
710709

docs/jax.export.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@ Classes
1414
.. autosummary::
1515
:toctree: _autosummary
1616

17-
Exported
18-
DisabledSafetyCheck
17+
.. autoclass:: Exported
18+
:members:
19+
20+
.. autoclass:: DisabledSafetyCheck
21+
:members:
1922

2023
Functions
2124
---------

jax/_src/export/_export.py

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ class Exported:
203203
_get_vjp: Callable[[Exported], Exported] | None
204204

205205
def mlir_module(self) -> str:
206+
"""A string representation of the `mlir_module_serialized`."""
206207
return xla_client._xla.mlir.deserialize_portable_artifact(self.mlir_module_serialized)
207208

208209
def __str__(self):
@@ -211,8 +212,8 @@ def __str__(self):
211212
return f"Exported(fun_name={self.fun_name}, ...)"
212213

213214
def in_shardings_jax(
214-
self,
215-
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
215+
self,
216+
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
216217
"""Creates Shardings corresponding to self.in_shardings_hlo.
217218
218219
The Exported object stores `in_shardings_hlo` as HloShardings, which are
@@ -221,38 +222,39 @@ def in_shardings_jax(
221222
`jax.device_put`.
222223
223224
Example usage:
224-
>>> from jax import export
225-
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
226-
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
227-
... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
228-
... )(np.arange(jax.device_count()))
229-
>>> exp.in_shardings_hlo
230-
({devices=[8]<=[8]},)
231-
232-
# Create a mesh for running the exported object
233-
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
234-
>>>
235-
# Put the args and kwargs on the appropriate devices
236-
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
237-
... exp.in_shardings_jax(run_mesh)[0])
238-
>>> res = exp.call(run_arg)
239-
>>> res.addressable_shards
240-
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
241-
Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
242-
Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
243-
Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
244-
Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
245-
Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
246-
Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
247-
Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
225+
226+
>>> from jax import export
227+
>>> # Prepare the exported object:
228+
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
229+
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
230+
... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
231+
... )(np.arange(jax.device_count()))
232+
>>> exp.in_shardings_hlo
233+
({devices=[8]<=[8]},)
234+
>>> # Create a mesh for running the exported object
235+
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
236+
>>> # Put the args and kwargs on the appropriate devices
237+
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
238+
... exp.in_shardings_jax(run_mesh)[0])
239+
>>> res = exp.call(run_arg)
240+
>>> res.addressable_shards
241+
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
242+
Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
243+
Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
244+
Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
245+
Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
246+
Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
247+
Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
248+
Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
249+
248250
"""
249251
return tuple(_hlo_sharding_to_xla_compatible_sharding(s, mesh)
250252
for s in self.in_shardings_hlo)
251253

252254
def out_shardings_jax(
253255
self,
254256
mesh: sharding.Mesh) -> Sequence[sharding.Sharding | None]:
255-
"""Creates Shardings corresponding to self.out_shardings_hlo.
257+
"""Creates Shardings corresponding to `self.out_shardings_hlo`.
256258
257259
See documentation for in_shardings_jax.
258260
"""
@@ -289,6 +291,21 @@ def serialize(self,
289291
return serialize(self, vjp_order=vjp_order)
290292

291293
def call(self, *args, **kwargs):
294+
"""Call an exported function from a JAX program.
295+
296+
Args:
297+
args: the positional arguments to pass to the exported function. This
298+
should be a pytree of arrays with the same pytree structure as the
299+
arguments for which the function was exported.
300+
kwargs: the keyword arguments to pass to the exported function.
301+
302+
Returns: a pytree of result array, with the same structure as the
303+
results of the exported function.
304+
305+
The invocation supports reverse-mode AD, and all the features supported
306+
by exporting: shape polymorphism, multi-platform, device polymorphism.
307+
See the examples in the [JAX export documentation](https://jax.readthedocs.io/en/latest/export/export.html).
308+
"""
292309
return call_exported(self)(*args, **kwargs)
293310

294311

0 commit comments

Comments
 (0)