We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent baedb62 commit 861115aCopy full SHA for 861115a
jax/experimental/multihost_utils.py
@@ -75,7 +75,7 @@ def pre_jit(x):
75
return host_local_array_to_global_array(inp, global_mesh, pspec)
76
77
def post_jit(x):
78
- return np.asarray(x.addressable_data(0))
+ return jax.device_get(x.addressable_data(0))
79
80
in_tree = jax.tree.map(pre_jit, in_tree)
81
out_tree = jax.jit(_psum, out_shardings=jax.sharding.NamedSharding(
0 commit comments