Skip to content

Commit 6617a0d

Browse files
junwhanahnjax authors
authored andcommitted
Expand device_put benchmarks to run with different numbers of arrays and input types
For the upcoming batching changes for `device_put`, it is useful to benchmark `device_put` with varying numbers of arrays. PiperOrigin-RevId: 641716268
1 parent a8246ea commit 6617a0d

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

benchmarks/api_benchmark.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -678,10 +678,29 @@ def host_local_array_to_global_array(state):
678678
(input_data, input_data), global_mesh, (in_pspec, in_pspec))
679679

680680
@google_benchmark.register
681-
def device_put(state):
682-
x = np.array(1, np.int32)
681+
@google_benchmark.option.arg_names(['num_args'])
682+
@google_benchmark.option.args([1])
683+
@google_benchmark.option.args([10])
684+
@google_benchmark.option.args([100])
685+
@google_benchmark.option.args([1000])
686+
def device_put_from_numpy_array(state):
687+
x = [np.array(1, np.int32)] * state.range(0)
683688
while state:
684-
_ = jax.device_put(x).block_until_ready()
689+
_ = jax.block_until_ready(jax.device_put(x))
690+
691+
692+
@google_benchmark.register
693+
@google_benchmark.option.arg_names(['num_args'])
694+
@google_benchmark.option.args([1])
695+
@google_benchmark.option.args([10])
696+
@google_benchmark.option.args([100])
697+
@google_benchmark.option.args([1000])
698+
def device_put_from_jax_array(state):
699+
x = [np.array(1, np.int32)] * state.range(0)
700+
x = jax.block_until_ready(jax.device_put(x, device=jax.devices()[0]))
701+
d = jax.devices()[1]
702+
while state:
703+
_ = jax.block_until_ready(jax.device_put(x, device=d))
685704

686705

687706
@google_benchmark.register

0 commit comments

Comments
 (0)