@@ -203,6 +203,7 @@ class Exported:
203203 _get_vjp : Callable [[Exported ], Exported ] | None
204204
205205 def mlir_module (self ) -> str :
206+ """A string representation of the `mlir_module_serialized`."""
206207 return xla_client ._xla .mlir .deserialize_portable_artifact (self .mlir_module_serialized )
207208
208209 def __str__ (self ):
@@ -211,8 +212,8 @@ def __str__(self):
211212 return f"Exported(fun_name={ self .fun_name } , ...)"
212213
213214 def in_shardings_jax (
214- self ,
215- mesh : sharding .Mesh ) -> Sequence [sharding .Sharding | None ]:
215+ self ,
216+ mesh : sharding .Mesh ) -> Sequence [sharding .Sharding | None ]:
216217 """Creates Shardings corresponding to self.in_shardings_hlo.
217218
218219 The Exported object stores `in_shardings_hlo` as HloShardings, which are
@@ -221,38 +222,39 @@ def in_shardings_jax(
221222 `jax.device_put`.
222223
223224 Example usage:
224- >>> from jax import export
225- >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
226- >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
227- ... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
228- ... )(np.arange(jax.device_count()))
229- >>> exp.in_shardings_hlo
230- ({devices=[8]<=[8]},)
231-
232- # Create a mesh for running the exported object
233- >>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
234- >>>
235- # Put the args and kwargs on the appropriate devices
236- >>> run_arg = jax.device_put(np.arange(jax.device_count()),
237- ... exp.in_shardings_jax(run_mesh)[0])
238- >>> res = exp.call(run_arg)
239- >>> res.addressable_shards
240- [Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
241- Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
242- Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
243- Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
244- Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
245- Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
246- Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
247- Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
225+
226+ >>> from jax import export
227+ >>> # Prepare the exported object:
228+ >>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
229+ >>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
230+ ... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
231+ ... )(np.arange(jax.device_count()))
232+ >>> exp.in_shardings_hlo
233+ ({devices=[8]<=[8]},)
234+ >>> # Create a mesh for running the exported object
235+ >>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
236+ >>> # Put the args and kwargs on the appropriate devices
237+ >>> run_arg = jax.device_put(np.arange(jax.device_count()),
238+ ... exp.in_shardings_jax(run_mesh)[0])
239+ >>> res = exp.call(run_arg)
240+ >>> res.addressable_shards
241+ [Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
242+ Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
243+ Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
244+ Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
245+ Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
246+ Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
247+ Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
248+ Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
249+
248250 """
249251 return tuple (_hlo_sharding_to_xla_compatible_sharding (s , mesh )
250252 for s in self .in_shardings_hlo )
251253
252254 def out_shardings_jax (
253255 self ,
254256 mesh : sharding .Mesh ) -> Sequence [sharding .Sharding | None ]:
255- """Creates Shardings corresponding to self.out_shardings_hlo.
257+ """Creates Shardings corresponding to ` self.out_shardings_hlo` .
256258
257259 See documentation for in_shardings_jax.
258260 """
@@ -289,6 +291,21 @@ def serialize(self,
289291 return serialize (self , vjp_order = vjp_order )
290292
291293 def call (self , * args , ** kwargs ):
294+ """Call an exported function from a JAX program.
295+
296+ Args:
297+ args: the positional arguments to pass to the exported function. This
298+ should be a pytree of arrays with the same pytree structure as the
299+ arguments for which the function was exported.
300+ kwargs: the keyword arguments to pass to the exported function.
301+
302+ Returns: a pytree of result array, with the same structure as the
303+ results of the exported function.
304+
305+ The invocation supports reverse-mode AD, and all the features supported
306+ by exporting: shape polymorphism, multi-platform, device polymorphism.
307+ See the examples in the [JAX export documentation](https://jax.readthedocs.io/en/latest/export/export.html).
308+ """
292309 return call_exported (self )(* args , ** kwargs )
293310
294311
0 commit comments