Skip to content

Commit bbaec6e

Browse files
hyeontaekGoogle-ML-Automation
authored andcommitted
[JAX] Add Python binding for building a colocated Python program
This change adds a Python binding that makes `ifrt::CustomCallProgram` for a colocated Python program. This Python binding will be used internally in the colocated Python API implementation. The API does not yet compile the program into an executable, which will be added separately. PiperOrigin-RevId: 700443656
1 parent 6763fcf commit bbaec6e

File tree

4 files changed

+29
-5
lines changed

4 files changed

+29
-5
lines changed

jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,5 +1198,6 @@ pytype_library(
11981198
":util",
11991199
":xla_bridge",
12001200
"//jax/_src/lib",
1201+
"//jax/extend:ifrt_programs",
12011202
] + py_deps("numpy") + py_deps("cloudpickle"),
12021203
)

jax/experimental/colocated_python/func.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
from jax._src.traceback_util import api_boundary
2929
from jax._src.util import wraps
3030
from jax.experimental.colocated_python import func_backend
31-
from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize_specs
31+
from jax.experimental.colocated_python.serialization import _deserialize_specs, _make_specs_for_serialized_specs, _serialize, _serialize_specs
32+
from jax.extend.ifrt_programs import ifrt_programs
3233

3334
ShapeDtypeStructTree = Any # PyTree[api.ShapeDtypeStruct]
3435

@@ -141,8 +142,13 @@ def _compile_to_executable(
141142
devices: xc.DeviceList,
142143
) -> Callable[..., Any]:
143144
"""Compiles a Python function into a runtime executable."""
144-
# TODO(hyeontaek): Wrap fun as CustomCallProgram and compile it into an
145-
# executable.
145+
pickled_function = _serialize(fun)
146+
program = ifrt_programs.make_colocated_python_program(
147+
name, pickled_function, devices, in_specs_leaves, out_specs_leaves
148+
)
149+
# TODO(hyeontaek): Compile the program and use the executable.
150+
del program
151+
146152
del name
147153
del in_specs_leaves
148154
del out_specs_leaves

tests/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,7 @@ jax_multiplatform_test(
13871387
srcs = ["colocated_python_test.py"],
13881388
deps = [
13891389
"//jax:experimental_colocated_python",
1390+
"//jax/extend:ifrt_programs",
13901391
],
13911392
)
13921393

tests/colocated_python_test.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from jax._src.lib import xla_extension_version # pylint: disable=g-importing-member
2323
from jax.experimental import colocated_python
2424
from jax.experimental.colocated_python import func as colocated_python_func
25+
from jax.experimental.colocated_python import serialization
26+
from jax.extend.ifrt_programs import ifrt_programs
2527
import jax.numpy as jnp
2628
import numpy as np
2729

@@ -77,8 +79,22 @@ class ColocatedPythonTest(jtu.JaxTestCase):
7779

7880
def setUp(self):
7981
super().setUp()
80-
if xla_extension_version < 290:
81-
self.skipTest("Requires xla_extension_version >= 290")
82+
if xla_extension_version < 298:
83+
self.skipTest("Requires xla_extension_version >= 298")
84+
85+
def testMakeColocatedPythonProgram(self):
86+
def add_one(x):
87+
return x + 1
88+
89+
cpu_devices = _colocated_cpu_devices(jax.local_devices())
90+
sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0])
91+
aval = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding)
92+
93+
pickled_function = serialization._serialize(add_one)
94+
program = ifrt_programs.make_colocated_python_program(
95+
"add_one", pickled_function, [cpu_devices[0]], [aval], [aval]
96+
)
97+
del program
8298

8399
def testSimpleFunction(self):
84100
@colocated_python.colocated_python

0 commit comments

Comments
 (0)