
Commit 2ade7e7

gnecula authored and jax authors committed
[pallas] Move the hardware_generation query in the code path that needs it
This change allows us to lower and export Pallas calls even on machines that do not have TPUs, in many cases.

PiperOrigin-RevId: 641841079
1 parent af95803 commit 2ade7e7
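As a hedged sketch of the workflow this change enables (mirroring the test added below; add_one_kernel and add_one are illustrative names, not part of the commit):

import jax
import numpy as np
from jax.experimental import export
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
  # A trivial Pallas kernel: read the block, add one, write it back.
  o_ref[...] = x_ref[...] + 1

@jax.jit
def add_one(x: jax.Array) -> jax.Array:
  return pl.pallas_call(
      add_one_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)

# With this change, lowering for TPU succeeds even on a host with no TPUs
# (provided the Mosaic Python pipeline is not forced on via flags).
exp = export.export(add_one, lowering_platforms=["tpu"])(np.arange(8))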

File tree

3 files changed: +94 -5 lines changed


jax/_src/tpu_custom_call.py

Lines changed: 9 additions & 5 deletions
@@ -384,16 +384,12 @@ def as_tpu_kernel(
 ) -> Callable[..., Any]:
   """Turns an MLIR Mosaic kernel into a JAX-compatible function."""
   # We use jax.jit to make sure we hit the fast compilation cache.
-  some_tpu = jax.devices(backend)[0]
-  device_kind = some_tpu.device_kind
-  if not device_kind.startswith("TPU v"):
-    raise ValueError(f"Unrecognized TPU device kind: {device_kind}.")
+
   if vmem_limit_bytes is not None and not isinstance(vmem_limit_bytes, int):
     raise ValueError(
         "vmem_limit_bytes must be an int: provided with a"
         f" {type(vmem_limit_bytes)}."
     )
-  hardware_generation = int(device_kind[len("TPU v")])
   has_communication, has_custom_barrier = tpu.private_has_communication(
       module.operation
   )
@@ -405,6 +401,14 @@ def as_tpu_kernel(
       module.operation.get_asm(binary=True, enable_debug_info=True)
   )
   if needs_layout_passes and _MOSAIC_USE_PYTHON_PIPELINE.value:
+    some_tpu = jax.devices(backend)[0]
+    device_kind = some_tpu.device_kind
+    if not device_kind.startswith("TPU v"):
+      raise ValueError(
+          f"Unrecognized TPU device kind: {device_kind}. "
+          "tpu_custom_call cannot be lowered on a machine without TPUs "
+          "when mosaic_use_python_pipeline=True.")
+    hardware_generation = int(device_kind[len("TPU v")])
     module = _lower_tpu_kernel(module, hardware_generation)
     needs_hlo_passes = False
     needs_layout_passes = False
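
What moved is a small parsing step. As a standalone sketch (plain Python, not part of the diff), the deferred query boils down to:

def parse_hardware_generation(device_kind: str) -> int:
  # device_kind comes from a TPU device and looks like "TPU v4"; the
  # single digit after the "TPU v" prefix is the hardware generation.
  if not device_kind.startswith("TPU v"):
    raise ValueError(f"Unrecognized TPU device kind: {device_kind}.")
  return int(device_kind[len("TPU v")])

assert parse_hardware_generation("TPU v4") == 4

Since jax.devices(backend) fails on a machine with no TPUs, running this query unconditionally made every lowering fail there; after the move it runs only when the Mosaic Python pipeline actually needs the generation number.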

tests/pallas/BUILD

Lines changed: 29 additions & 0 deletions
@@ -330,3 +330,32 @@ jax_test(
         "//jax:pallas_gpu",  # build_cleaner: keep
     ],
 )
+
+jax_test(
+    name = "export_pallas_test",
+    srcs = ["export_pallas_test.py"],
+    config_tags_overrides = {
+        "gpu_a100_x32": {
+            "ondemand": False,  # Include in presubmit.
+        },
+    },
+    disable_configs = [
+        "gpu",
+        "gpu_x32",
+        "gpu_a100",
+        "gpu_h100",
+        "gpu_p100",
+        "gpu_p100_x32",
+        "gpu_pjrt_c_api",
+    ],
+    enable_configs = [
+        "gpu_a100_x32",
+    ],
+    tags = [],
+    deps = [
+        "//jax:pallas",
+        "//jax:pallas_gpu",  # build_cleaner: keep
+        "//jax:pallas_tpu",  # build_cleaner: keep
+        "//jax/experimental/export",
+    ],
+)
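
A note on this configuration: the test is disabled on the generic GPU configs and enabled only on gpu_a100_x32, which is pulled into presubmit via "ondemand": False. It depends on //jax:pallas_gpu and //jax:pallas_tpu as well as //jax/experimental/export, so the TPU lowering path it exercises is linked in regardless of which backend the test actually runs on.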

tests/pallas/export_pallas_test.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test exporting Pallas kernels."""
+
+from absl.testing import absltest
+import jax
+from jax._src import test_util as jtu
+from jax.experimental import export
+# Import mosaic for flag definitions
+from jax.experimental import mosaic as _  # noqa: F401
+from jax.experimental import pallas as pl
+import numpy as np
+
+
+jax.config.parse_flags_with_absl()
+
+
+class ExportTest(jtu.JaxTestCase):
+
+  def test_cross_platform(self):
+    def add_vectors_kernel(x_ref, y_ref, o_ref):
+      x, y = x_ref[...], y_ref[...]
+      o_ref[...] = x + y
+
+    @jax.jit
+    def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
+      return pl.pallas_call(add_vectors_kernel,
+                            out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
+                            )(x, y)
+
+    a = np.arange(8)
+    exp = export.export(
+        add_vectors,
+        # TODO(necula): Make this test work on GPU also
+        lowering_platforms=["tpu"],
+    )(a, a)
+
+    if jtu.device_under_test() == "tpu":
+      res = export.call(exp)(a, a)
+      self.assertAllClose(res, a + a)
+
+
+if __name__ == '__main__':
+  absltest.main(testLoader=jtu.JaxTestLoader())
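
On a host without TPUs, the test still exercises lowering and export; only the execution under export.call is guarded by the device check. As a hedged illustration (not part of the commit; Exported.mlir_module() and .lowering_platforms are assumed from the jax.experimental.export API of this era), the exported artifact can be inspected with no TPU present:

# Continues from `exp` above; the attributes used are assumptions, not
# anything the test itself relies on.
print(exp.lowering_platforms)   # expected: ("tpu",)
print(exp.mlir_module()[:80])   # start of the lowered StableHLO text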
