21 | 21 | import google_benchmark |
22 | 22 | import jax |
23 | 23 | from jax import lax |
24 | | -from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation |
25 | | -from jax._src import core |
26 | | -from jax._src.lib import xla_client as xc |
27 | 24 | from jax._src import array |
| 25 | +from jax._src import core |
28 | 26 | from jax._src import op_shardings |
| 27 | +from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation |
| 28 | +from jax._src.lib import xla_client as xc |
29 | 29 | from jax._src.pjit import pjit_check_aval_sharding |
30 | | -from jax.experimental import pjit as pjit_lib |
31 | 30 | from jax.experimental import multihost_utils |
| 31 | +from jax.experimental import pjit as pjit_lib |
32 | 32 | import jax.numpy as jnp |
33 | 33 | import numpy as np |
34 | 34 |
@@ -860,29 +860,44 @@ def safe_zip(state): |
860 | 860 |
861 | 861 | @google_benchmark.register |
862 | 862 | def bench_make_array_from_callback_fully_replicated_sharding(state): |
863 | | - mesh = jax.sharding.Mesh( |
864 | | - np.array(jax.devices()[:8]).reshape((4, 2)), ('x', 'y')) |
865 | | - shape = (8, 2) |
866 | | - np_arr = np.arange(16).reshape(shape) |
867 | | - s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) |
| 863 | + mesh = create_mesh((4, 2), ('x', 'y'), state) |
| 864 | + if mesh is None: |
| 865 | + return |
| 866 | + input_shape = (8, 2) |
| 867 | + np_arr = np.arange(math.prod(input_shape)).reshape(input_shape) |
868 | 868 |
| 869 | + s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) |
869 | 870 | while state: |
870 | | - jax.make_array_from_callback(shape, s, np_arr.__getitem__) |
| 871 | + jax.make_array_from_callback(input_shape, s, np_arr.__getitem__) |
871 | 872 |
872 | 873 |
873 | 874 | @google_benchmark.register |
874 | 875 | @google_benchmark.option.unit(google_benchmark.kMillisecond) |
875 | | -def bench_make_array_from_callback_sharded(state): |
876 | | - global_mesh = create_mesh((4, 2), ('x', 'y'), state) |
| 876 | +def bench_make_array_from_callback_partially_replicated_sharding(state): |
| 877 | + mesh = create_mesh((4, 2), ('x', 'y'), state) |
| 878 | + if mesh is None: |
| 879 | + return |
877 | 880 | input_shape = (8, 2) |
878 | | - input_data = np.arange(math.prod(input_shape)).reshape(input_shape) |
| 881 | + np_arr = np.arange(math.prod(input_shape)).reshape(input_shape) |
| 882 | + |
| 883 | + s = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None, 'y')) |
| 884 | + while state: |
| 885 | + jax.make_array_from_callback(input_shape, s, np_arr.__getitem__) |
879 | 886 |
880 | | - def callback(index): |
881 | | - return input_data[index] |
882 | 887 |
883 | | - s = jax.NamedSharding(global_mesh, jax.sharding.PartitionSpec('x', 'y')) |
| 888 | +@google_benchmark.register |
| 889 | +@google_benchmark.option.unit(google_benchmark.kMillisecond) |
| 890 | +def bench_make_array_from_callback_fully_sharded_sharding(state): |
| 891 | + mesh = create_mesh((4, 2), ('x', 'y'), state) |
| 892 | + if mesh is None: |
| 893 | + return |
| 894 | + input_shape = (8, 2) |
| 895 | + np_arr = np.arange(math.prod(input_shape)).reshape(input_shape) |
| 896 | + |
| 897 | + s = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('x', 'y')) |
884 | 898 | while state: |
885 | | - jax.make_array_from_callback((8, 2), s, callback) |
| 899 | + jax.make_array_from_callback(input_shape, s, np_arr.__getitem__) |
| 900 | + |
886 | 901 |
887 | 902 | @google_benchmark.register |
888 | 903 | @google_benchmark.option.unit(google_benchmark.kMillisecond) |
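
A note for readers of the hunk above: the new benchmarks rely on a create_mesh helper defined elsewhere in this file, which is why each one bails out with `if mesh is None: return`. Below is a minimal sketch of what that helper presumably does, assuming it builds a Mesh over the first devices and skips the benchmark via google_benchmark's skip_with_error when too few devices are available; the real definition may differ.

import math

import jax
import numpy as np


def create_mesh(shape, axis_names, state):
  # Hypothetical reconstruction of the helper referenced by the benchmarks
  # above; the actual helper lives elsewhere in this file and may differ.
  size = math.prod(shape)
  if len(jax.devices()) < size:
    # Not enough devices: mark the benchmark as skipped and signal the
    # caller to return early (the `if mesh is None: return` pattern above).
    state.skip_with_error(f'requires {size} devices')
    return None
  devices = np.array(jax.devices()[:size]).reshape(shape)
  return jax.sharding.Mesh(devices, axis_names)

The three benchmarks then differ only in the PartitionSpec given to NamedSharding: PartitionSpec() fully replicates the (8, 2) array on every device of the 4x2 mesh, PartitionSpec(None, 'y') shards only the second dimension across 'y' while replicating across 'x', and PartitionSpec('x', 'y') shards both dimensions so each device holds a 2x1 shard.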