@@ -333,10 +333,10 @@ def in_shardings_jax(
333
333
`jax.device_put`.
334
334
335
335
Example usage:
336
- >>> from jax.experimental import export
336
+ >>> from jax import export
337
337
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
338
338
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
339
- ... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
339
+ ... in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
340
340
... )(np.arange(jax.device_count()))
341
341
>>> exp.in_shardings_hlo
342
342
({devices=[8]<=[8]},)
@@ -347,7 +347,7 @@ def in_shardings_jax(
347
347
# Put the args and kwargs on the appropriate devices
348
348
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
349
349
... exp.in_shardings_jax(run_mesh)[0])
350
- >>> res = export .call(exp) (run_arg)
350
+ >>> res = exp .call(run_arg)
351
351
>>> res.addressable_shards
352
352
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
353
353
Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
@@ -372,19 +372,53 @@ def out_shardings_jax(
372
372
for s in self .out_shardings_hlo )
373
373
374
374
def has_vjp (self ) -> bool :
375
+ """Returns whether this Exported supports VJP."""
375
376
return self ._get_vjp is not None
376
377
377
378
def vjp (self ) -> Exported :
378
379
"""Gets the exported VJP.
379
380
380
381
Raises ValueError if not available, which can happen if the Exported has been
381
- loaded from an external format, without a VJP."""
382
+ loaded from an external format without a VJP.
383
+ """
382
384
if self ._get_vjp is None :
383
385
raise ValueError ("No VJP is available" )
384
386
return self ._get_vjp (self )
385
387
388
+ def serialize (self ,
389
+ vjp_order : int = 0 ) -> bytearray :
390
+ """Serializes an Exported.
391
+
392
+ Args:
393
+ vjp_order: The maximum vjp order to include. E.g., the value 2 means that we
394
+ serialize the primal functions and two orders of the `vjp` function. This
395
+ should allow 2nd order reverse mode differentiation of the deserialized
396
+ function, i.e., `jax.grad(jax.grad(f))`.
397
+ """
398
+ # Lazy load the serialization module, since flatbuffers is an optional
399
+ # dependency.
400
+ from jax ._src .export .serialization import serialize
401
+ return serialize (self , vjp_order = vjp_order )
402
+
403
+ def call (self , * args , ** kwargs ):
404
+ return call_exported (self )(* args , ** kwargs )
405
+
406
+
407
+ def deserialize (blob : bytearray ) -> Exported :
408
+ """Deserializes an Exported.
409
+
410
+ Args:
411
+ blob: a bytearray obtained from `Exported.serialize`.
412
+ """
413
+ # Lazy load the serialization module, since flatbuffers is an optional
414
+ # dependency.
415
+ from jax ._src .export .serialization import deserialize
416
+ return deserialize (blob )
417
+
386
418
387
419
def default_lowering_platform () -> str :
420
+ """Retrieves the default lowering platform for the exporting machine.
421
+ """
388
422
# Canonicalize to turn 'gpu' into 'cuda' or 'rocm'
389
423
return xb .canonicalize_platform (jax .default_backend ())
390
424
@@ -411,14 +445,20 @@ def args_specs(
411
445
return shape_poly .symbolic_args_specs (args , polymorphic_shapes )
412
446
413
447
414
- def export (fun_jax : Callable ,
415
- * ,
416
- lowering_platforms : Sequence [str ] | None = None ,
417
- disabled_checks : Sequence [DisabledSafetyCheck ] = (),
418
- _device_assignment_for_internal_jax2tf_use_only = None ,
419
- ) -> Callable [..., Exported ]:
448
+ # TODO(necula): remove this once we remove jax.experimental.export.
449
+ def export_back_compat (
450
+ fun_jax : Callable ,
451
+ * ,
452
+ lowering_platforms : Sequence [str ] | None = None ,
453
+ disabled_checks : Sequence [DisabledSafetyCheck ] = (),
454
+ _device_assignment_for_internal_jax2tf_use_only = None ,
455
+ ) -> Callable [..., Exported ]:
420
456
"""Exports native serialization for a JAX function.
421
457
458
+ Note: this function exists only for internal usage by jax2tf and for
459
+ backwards compatibility with jax.experimental.export. Use
460
+ `jax.export` instead.
461
+
422
462
Args:
423
463
fun_jax: the function to lower and serialize.
424
464
lowering_platforms:
@@ -498,6 +538,85 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
498
538
_device_assignment_for_internal_jax2tf_use_only = _device_assignment_for_internal_jax2tf_use_only )
499
539
return do_export
500
540
541
+ def export (
542
+ fun_jit : stages .Wrapped ,
543
+ * ,
544
+ lowering_platforms : Sequence [str ] | None = None ,
545
+ disabled_checks : Sequence [DisabledSafetyCheck ] = (),
546
+ ) -> Callable [..., Exported ]:
547
+ """Exports a JAX function for persistent serialization.
548
+
549
+ Args:
550
+ fun_jit: the function to export. Must be the result of `jit`.
551
+ lowering_platforms:
552
+ Optional sequence containing a subset of 'tpu', 'cpu',
553
+ 'cuda', 'rocm'. If more than one platform is specified, then
554
+ the lowered code takes an argument specifying the platform.
555
+ If None, then use the default JAX backend.
556
+ The calling convention for multiple platforms is explained in the
557
+ `Exported` docstring.
558
+ disabled_checks: the safety checks to disable. See docstring
559
+ of `DisabledSafetyCheck`.
560
+
561
+ Returns: a function that takes args and kwargs pytrees of jax.ShapeDtypeStruct,
562
+ or values with `.shape` and `.dtype` attributes, and returns an
563
+ `Exported`.
564
+
565
+ Usage:
566
+ >>> from jax import export
567
+ >>> exported: export.Exported = export.export(jnp.sin)(
568
+ ... np.arange(4, dtype=np.float32))
569
+
570
+ # You can inspect the Exported object
571
+ >>> exported.in_avals
572
+ (ShapedArray(float32[4]),)
573
+ >>> blob: bytearray = exported.serialize()
574
+
575
+ # The serialized bytes are safe to use in a separate process
576
+ >>> rehydrated: export.Exported = export.deserialize(blob)
577
+ >>> rehydrated.fun_name
578
+ 'sin'
579
+ >>> rehydrated.call(np.array([.1, .2, .3, .4], dtype=np.float32))
580
+ Array([0.09983342, 0.19866933, 0.29552022, 0.38941833], dtype=float32)
581
+ """
582
+ if not isinstance (fun_jit , stages .Wrapped ):
583
+ raise ValueError (
584
+ f"Function to be exported must be the result of `jit` but is: { fun_jit } " )
585
+ if lowering_platforms is not None :
586
+ actual_lowering_platforms = tuple (lowering_platforms )
587
+ else :
588
+ actual_lowering_platforms = (default_lowering_platform (),)
589
+
590
+ def do_export (* args_specs , ** kwargs_specs ) -> Exported :
591
+ # TODO: move to `lower`
592
+ symbolic_scope : tuple [shape_poly .SymbolicScope , tree_util .KeyPath ] | None = None # type: ignore[invalid-annotation,unused-ignore]
593
+ for k_path , aval in tree_util .tree_flatten_with_path ((args_specs , kwargs_specs ))[0 ]:
594
+ # Static args may have no `shape` attribute.
595
+ if not hasattr (aval , "shape" ):
596
+ continue
597
+ for d in aval .shape :
598
+ if shape_poly .is_symbolic_dim (d ):
599
+ if symbolic_scope is None :
600
+ symbolic_scope = (d .scope , k_path )
601
+ continue
602
+ symbolic_scope [0 ]._check_same_scope (
603
+ d , when = f"when exporting { util .fun_name (fun_jit )} " ,
604
+ self_descr = f"current (from { shape_poly .args_kwargs_path_to_str (symbolic_scope [1 ])} ) " ,
605
+ other_descr = shape_poly .args_kwargs_path_to_str (k_path ))
606
+
607
+ traced = fun_jit .trace ( # type: ignore
608
+ * args_specs , ** kwargs_specs ,
609
+ _experimental_lowering_parameters = mlir .LoweringParameters (
610
+ platforms = actual_lowering_platforms ,
611
+ for_export = True ,
612
+ ))
613
+ jaxpr , fun_name = traced .jaxpr , traced .fun_name
614
+ lowered = traced .lower ()
615
+ return _export_lowered (
616
+ lowered , jaxpr , fun_name ,
617
+ disabled_checks = disabled_checks )
618
+ return do_export
619
+
501
620
def _export_lowered (
502
621
lowered : stages .Lowered ,
503
622
jaxpr : core .ClosedJaxpr , fun_name : str ,
@@ -599,7 +718,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported:
599
718
device_assignment = device_assignment ,
600
719
apply_jit = True ,
601
720
flat_primal_fun = True )
602
- return export (fun_vjp_jax ,
721
+ return export (fun_vjp_jax , # type: ignore[arg-type]
603
722
lowering_platforms = exp_primal .lowering_platforms ,
604
723
disabled_checks = exp_primal .disabled_safety_checks )(* vjp_in_avals )
605
724
@@ -816,7 +935,7 @@ def is_token(typ, attrs):
816
935
817
936
def _check_lowering (lowering ) -> None :
818
937
if not isinstance (lowering , pxla .MeshComputation ):
819
- raise NotImplementedError (f"serialization is supported only for pjit . { lowering } " )
938
+ raise NotImplementedError (f"serialization is supported only for jit . { lowering } " )
820
939
821
940
if lowering .compile_args ["host_callbacks" ] or lowering .compile_args ["keepalive" ]:
822
941
raise NotImplementedError ("serialization of host_callbacks is not yet implemented" )
0 commit comments