Commit 2912a1c

hyeontaek authored and Google-ML-Automation committed
[JAX] Fix colocated_python.colocated_cpu_devices() for multi-host CPU setups
On McJAX, `colocated_python.colocated_cpu_devices()` was querying only local CPU devices when finding colocated devices for (global) TPU devices. The API should have queried all CPU devices regardless of their addressability.

PiperOrigin-RevId: 833889551
1 parent f29296e · commit 2912a1c
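In short, the behavioral difference, shown as a hedged sketch (the two-host device counts below are hypothetical, not taken from the commit):

    import jax

    # Hypothetical setup: 2 hosts, each with 4 TPU devices and 4 CPU devices.
    jax.devices()                      # all 8 TPU devices (global view)
    jax.local_devices(backend="cpu")   # pre-fix query: the 4 CPU devices
                                       # addressable from this process only
    jax.devices(backend="cpu")         # post-fix query: all 8 CPU devices,
                                       # regardless of addressability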

File tree

3 files changed: +85 -1 lines changed

jax/experimental/colocated_python/api.py
tests/multiprocess/BUILD
tests/multiprocess/colocated_python_test.py (new)

jax/experimental/colocated_python/api.py

Lines changed: 1 addition & 1 deletion

@@ -115,7 +115,7 @@ def _colocated_cpu_devices_cached_fallback_to_cpu_backend(
   else:
     # PjRt-IFRT on a non-CPU platform currently defines CPU devices on a separate
     # CPU backend.
-    cpu_backend_devices = jax.local_devices(backend="cpu")
+    cpu_backend_devices = jax.devices(backend="cpu")
   device_index_map = {device.id: i for i, device in enumerate(jax.devices())}
 
   available_devices = devices[: min(len(cpu_backend_devices), len(devices))]
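For context, a usage sketch distilled from the new test added below (illustrative, not itself part of the diff):

    import jax
    import numpy as np
    from jax.experimental import colocated_python

    # Build a mesh over every accelerator device, then obtain its colocated
    # CPU counterpart; colocated_cpu_devices() accepts a Mesh (returning a
    # CPU Mesh) or a flat device sequence (returning CPU devices).
    mesh = jax.make_mesh((jax.device_count(),), ("x",))
    cpu_mesh = colocated_python.colocated_cpu_devices(mesh)
    cpu_sharding = jax.NamedSharding(cpu_mesh, jax.P("x"))

    # With the fix, cpu_mesh spans CPU devices on all hosts, so the array
    # below is sharded across every host's CPU backend.
    x = jax.device_put(np.arange(cpu_mesh.size), cpu_sharding)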

tests/multiprocess/BUILD

Lines changed: 19 additions & 0 deletions

@@ -78,6 +78,25 @@ jax_multiprocess_test(
     ],
 )
 
+jax_multiprocess_test(
+    name = "colocated_python_test",
+    srcs = ["colocated_python_test.py"],
+    disable_configs = [
+        # This config has two cores per chip, and JAX distributed does not get
+        # the correct number of logical devices per host.
+        "tpu_v3_x4",
+    ],
+    enable_backends = [
+        "gpu",
+        "tpu",
+    ],
+    main = "colocated_python_test.py",
+    deps = [
+        "//jax:experimental_colocated_python",
+        "//jax/_src:test_multiprocess",
+    ],
+)
+
 jax_multiprocess_test(
     name = "device_id_test",
     srcs = ["device_id_test.py"],
tests/multiprocess/colocated_python_test.py (new file)

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
+# Copyright 2025 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multihost tests for colocated Python."""
+
+import jax
+from jax._src import test_multiprocess as jt_multiprocess
+from jax._src import test_util as jtu
+from jax.experimental import colocated_python
+import numpy as np
+
+
+class ColocatedPythonTestMultiHost(jt_multiprocess.MultiProcessTest):
+
+  def setUp(self):
+    super().setUp()
+    jtu.request_cpu_devices(jax.local_device_count())
+
+  def test_colocated_cpu_devices(self):
+    if jax.device_count() % 2 == 0:
+      mesh_shape = (2, jax.device_count() // 2)
+    else:
+      mesh_shape = (1, jax.device_count())
+    mesh = jax.make_mesh(mesh_shape, ("x", "y"))
+    cpu_mesh1 = colocated_python.colocated_cpu_devices(mesh)
+
+    cpu_devices = colocated_python.colocated_cpu_devices(mesh.devices.flat)
+    cpu_mesh2 = jax.make_mesh(mesh_shape, ("x", "y"), devices=cpu_devices)
+    self.assertEqual(cpu_mesh1, cpu_mesh2)
+
+  def test_simple_function(self):
+    @colocated_python.colocated_python
+    def add_one(x):
+      return jax.make_array_from_single_device_arrays(
+          x.shape, x.sharding, [s.data + 1 for s in x.addressable_shards])
+
+    mesh = jax.make_mesh((jax.device_count(),), ("x",))
+    cpu_mesh = colocated_python.colocated_cpu_devices(mesh)
+    cpu_sharding = jax.NamedSharding(cpu_mesh, jax.P("x"))
+
+    x = np.arange(cpu_mesh.size)
+    x = jax.device_put(x, cpu_sharding)
+
+    out = add_one(x)
+
+    out = jax.jit(lambda x: x,
+                  out_shardings=jax.NamedSharding(cpu_mesh, jax.P()))(out)
+    out = jax.device_get(out)
+
+    np.testing.assert_equal(out, np.arange(cpu_mesh.size) + 1)
+
+
+if __name__ == "__main__":
+  jt_multiprocess.main()
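Assuming standard Bazel label conventions for the paths in tests/multiprocess/BUILD (a hypothetical invocation, not part of the commit), the new target would be run as:

    bazel test //tests/multiprocess:colocated_python_test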
