Skip to content

Commit 231967f

Browse files
[AutoPGLE] Explicitly ignore host callback pointers
Before this change, users had to set the remove_custom_partitioning_ptr_from_cache_key config flag when using AutoPGLE. PiperOrigin-RevId: 700289965
1 parent b6566c8 commit 231967f

File tree

5 files changed

+173
-55
lines changed

5 files changed

+173
-55
lines changed

jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,7 @@ pytype_strict_library(
426426
name = "compiler",
427427
srcs = ["_src/compiler.py"],
428428
deps = [
429+
":cache_key",
429430
":compilation_cache_internal",
430431
":config",
431432
":mlir",

jax/_src/cache_key.py

Lines changed: 79 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import copy
16+
import enum
1617
import hashlib
1718
import io
1819
import logging
@@ -62,11 +63,23 @@ def custom_hook() -> str:
6263
return ""
6364

6465

65-
def get(module: ir.Module,
66-
devices: np.ndarray,
67-
compile_options: xla_client.CompileOptions,
68-
backend: xla_client.Client,
69-
compression_algorithm: str = "zstandard") -> str:
66+
class IgnoreCallbacks(enum.IntEnum):
67+
# Do not remove any callback pointers from precompiled IR.
68+
NO = enum.auto()
69+
# Remove all callback pointers from precompiled IR.
70+
ALL = enum.auto()
71+
# Remove only the custom_partitioning callback pointer from precompiled IR.
72+
CUSTOM_PARTITIONING = enum.auto()
73+
74+
75+
def get(
76+
module: ir.Module,
77+
devices: np.ndarray,
78+
compile_options: xla_client.CompileOptions,
79+
backend: xla_client.Client,
80+
compression_algorithm: str = "zstandard",
81+
ignore_callbacks: IgnoreCallbacks = IgnoreCallbacks.NO,
82+
) -> str:
7083
"""Creates a hashed string to use as a key to the compilation cache.
7184
7285
Creates a cache key that is a hex-encoded string of a unique hash based on
@@ -79,28 +92,47 @@ def get(module: ir.Module,
7992
backend: description of the platform (e.g., TPU version)
8093
compression_algorithm: a string representing the compression algorithm used
8194
for the executable before persisting in the cache
95+
ignore_callbacks: which callback pointers (if any) to remove from the
96+
computation.
8297
8398
Typical return value example:
8499
'jit__psum-14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf'
85100
"""
86101
entries = [
87-
("computation",
88-
lambda hash_obj: _hash_computation(hash_obj, module)),
89-
("jax_lib version",
90-
lambda hash_obj: hash_obj.update(
91-
bytes(jaxlib_version_str.encode("utf-8")))),
92-
("XLA flags",
93-
lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes())),
94-
("compile_options",
95-
lambda hash_obj: _hash_serialized_compile_options(
96-
hash_obj, compile_options,
97-
# In case of GPU multi-process tasks we need to strip device
98-
# assignment to use cache key as invariant between processes.
99-
strip_device_assignment=(backend.platform == "gpu"))),
100-
("accelerator_config",
101-
lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend)),
102-
("compression",
103-
lambda hash_obj: _hash_string(hash_obj, compression_algorithm)),
102+
(
103+
"computation",
104+
lambda hash_obj: _hash_computation(
105+
hash_obj, module, ignore_callbacks
106+
),
107+
),
108+
(
109+
"jax_lib version",
110+
lambda hash_obj: hash_obj.update(
111+
bytes(jaxlib_version_str.encode("utf-8"))
112+
),
113+
),
114+
(
115+
"XLA flags",
116+
lambda hash_obj: _hash_xla_flags(hash_obj, get_flag_prefixes()),
117+
),
118+
(
119+
"compile_options",
120+
lambda hash_obj: _hash_serialized_compile_options(
121+
hash_obj,
122+
compile_options,
123+
# In case of GPU multi-process tasks we need to strip device
124+
# assignment to use cache key as invariant between processes.
125+
strip_device_assignment=(backend.platform == "gpu"),
126+
),
127+
),
128+
(
129+
"accelerator_config",
130+
lambda hash_obj: _hash_accelerator_config(hash_obj, devices, backend),
131+
),
132+
(
133+
"compression",
134+
lambda hash_obj: _hash_string(hash_obj, compression_algorithm),
135+
),
104136
("custom_hook", lambda hash_obj: _hash_string(hash_obj, custom_hook())),
105137
]
106138

@@ -131,45 +163,56 @@ def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
131163
)
132164

133165

134-
def _remove_custom_partitioning_ptr(m: ir.Module):
135-
"""
136-
Removes custom_partitioning callback pointer from precompiled IR.
166+
def _remove_callbacks(m: ir.Module, ignore_callbacks: IgnoreCallbacks):
167+
"""Removes callback pointers from precompiled IR.
168+
137169
Python function pointers are not deterministic across executions.
138170
"""
139171
def _update_bc_attribute(op: ir.Operation) -> ir.WalkResult:
140-
if (op.name == "stablehlo.custom_call" and
141-
op.attributes["call_target_name"].value == "CustomSPMDPartitioning"):
172+
if op.name == "stablehlo.custom_call" and (
173+
(
174+
ignore_callbacks == IgnoreCallbacks.ALL
175+
and op.attributes["call_target_name"].value.endswith("callback")
176+
)
177+
or op.attributes["call_target_name"].value == "CustomSPMDPartitioning"
178+
):
142179
op.attributes["backend_config"] = ir.StringAttr.get("REMOVED")
143180
return ir.WalkResult.ADVANCE
144181

182+
if ignore_callbacks == IgnoreCallbacks.NO:
183+
return m
184+
145185
m.operation.walk(_update_bc_attribute)
146186
return m
147187

148188

149-
def _serialize_ir(m: ir.Module) -> bytes:
189+
def _serialize_ir(m: ir.Module, ignore_callbacks: IgnoreCallbacks) -> bytes:
150190
output = io.BytesIO()
151-
if config.remove_custom_partitioning_ptr_from_cache_key.value:
152-
m = _remove_custom_partitioning_ptr(type_cast(ir.Module,
153-
m.operation.clone()))
191+
if ignore_callbacks != IgnoreCallbacks.NO:
192+
m = _remove_callbacks(
193+
type_cast(ir.Module, m.operation.clone()), ignore_callbacks
194+
)
154195
m.operation.write_bytecode(file=output)
155196
return output.getvalue()
156197

157198

158-
def _canonicalize_ir(m_original: ir.Module) -> bytes:
199+
def _canonicalize_ir(
200+
m_original: ir.Module, ignore_callbacks: IgnoreCallbacks
201+
) -> bytes:
159202
with m_original.context:
160203
m = type_cast(ir.Module, m_original.operation.clone())
161204
passes = pm.PassManager.parse(
162205
"builtin.module(strip-debuginfo)"
163206
)
164207
passes.run(m.operation)
165-
return _serialize_ir(m)
208+
return _serialize_ir(m, ignore_callbacks)
166209

167210

168-
def _hash_computation(hash_obj, module):
211+
def _hash_computation(hash_obj, module, ignore_callbacks: IgnoreCallbacks):
169212
if config.compilation_cache_include_metadata_in_key.value:
170-
canonical_ir = _serialize_ir(module)
213+
canonical_ir = _serialize_ir(module, ignore_callbacks)
171214
else:
172-
canonical_ir = _canonicalize_ir(module)
215+
canonical_ir = _canonicalize_ir(module, ignore_callbacks)
173216
hash_obj.update(canonical_ir)
174217

175218

jax/_src/compilation_cache.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -267,12 +267,21 @@ def put_executable_and_time(
267267
cache.put(cache_key, executable_and_time)
268268

269269

270-
def get_cache_key(module: ir.Module,
271-
devices: np.ndarray,
272-
compile_options,
273-
backend) -> str:
274-
return cache_key.get(module, devices, compile_options, backend,
275-
"zstandard" if zstandard is not None else "zlib")
270+
def get_cache_key(
271+
module: ir.Module,
272+
devices: np.ndarray,
273+
compile_options,
274+
backend,
275+
ignore_callbacks: cache_key.IgnoreCallbacks = cache_key.IgnoreCallbacks.NO,
276+
) -> str:
277+
return cache_key.get(
278+
module,
279+
devices,
280+
compile_options,
281+
backend,
282+
"zstandard" if zstandard is not None else "zlib",
283+
ignore_callbacks,
284+
)
276285

277286

278287
def is_initialized() -> bool:

jax/_src/compiler.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import Any, Callable
2525
import warnings
2626

27+
from jax._src import cache_key as cache_key_type
2728
from jax._src import compilation_cache
2829
from jax._src import config as config
2930
from jax._src import distributed
@@ -33,8 +34,8 @@
3334
from jax._src import profiler
3435
from jax._src import traceback_util
3536
from jax._src.interpreters import mlir
36-
from jax._src.lib import xla_client as xc
3737
from jax._src.lib import version as jaxlib_version
38+
from jax._src.lib import xla_client as xc
3839
from jax._src.lib.mlir import ir
3940
import numpy as np
4041

@@ -351,8 +352,18 @@ def compile_or_get_cached(
351352
monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache')
352353

353354
try:
355+
if config.remove_custom_partitioning_ptr_from_cache_key.value:
356+
ignore_callbacks = cache_key_type.IgnoreCallbacks.CUSTOM_PARTITIONING
357+
else:
358+
ignore_callbacks = cache_key_type.IgnoreCallbacks.NO
359+
354360
cache_key = compilation_cache.get_cache_key(
355-
computation, devices, compile_options, backend)
361+
computation,
362+
devices,
363+
compile_options,
364+
backend,
365+
ignore_callbacks=ignore_callbacks,
366+
)
356367
except xc._xla.XlaRuntimeError as ex:
357368
logger.error("compile_or_get_cached: unable to generate cache key, "
358369
"skipping the cache: %s", ex)
@@ -385,7 +396,12 @@ def compile_or_get_cached(
385396
compile_options.executable_build_options.fdo_profile = b"pgle profiled"
386397

387398
pgle_profiled_module_key = compilation_cache.get_cache_key(
388-
computation, devices, compile_options, backend)
399+
computation,
400+
devices,
401+
compile_options,
402+
backend,
403+
cache_key_type.IgnoreCallbacks.ALL,
404+
)
389405
compile_options.executable_build_options.fdo_profile = fdo_profile
390406

391407
if _is_executable_in_cache(backend, pgle_profiled_module_key):
@@ -493,7 +509,11 @@ def _share_fdo_profiles(
493509
compile_options.executable_build_options.fdo_profile = b""
494510
profile_key = (
495511
compilation_cache.get_cache_key(
496-
computation, devices, compile_options, backend
512+
computation,
513+
devices,
514+
compile_options,
515+
backend,
516+
cache_key_type.IgnoreCallbacks.ALL,
497517
)
498518
+ "_fdo_sync"
499519
)

tests/cache_key_test.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def _infer_sharding_from_operands(mesh, arg_shapes, result_shape):
176176

177177
@custom_partitioning
178178
def _cp_add(x, y):
179-
return jax.numpy.add(x, y)
179+
return jax.numpy.add(x, y)
180180

181181
_cp_add.def_partition(
182182
infer_sharding_from_operands=_infer_sharding_from_operands,
@@ -199,14 +199,59 @@ def _cp_add(x, y):
199199
r'(.*?backend_config\s*=\s*"([^"]*)".*?)'
200200
r'\}'
201201
)
202-
with config.remove_custom_partitioning_ptr_from_cache_key(True):
203-
with computation.context:
204-
updated_module = cache_key._remove_custom_partitioning_ptr(
205-
type_cast(ir.Module, computation.operation.clone()))
206-
bcs = [match[2] for
207-
match in re.findall(pattern, str(updated_module), re.DOTALL)]
208-
for bc in bcs:
209-
self.assertEqual(bc, "REMOVED")
202+
with computation.context:
203+
updated_module = cache_key._remove_callbacks(
204+
type_cast(ir.Module, computation.operation.clone()),
205+
ignore_callbacks=cache_key.IgnoreCallbacks.ALL,
206+
)
207+
bcs = [
208+
match[2]
209+
for match in re.findall(pattern, str(updated_module), re.DOTALL)
210+
]
211+
for bc in bcs:
212+
self.assertEqual(bc, "REMOVED")
213+
214+
compile_options = compiler.get_compile_options(
215+
num_replicas=1, num_partitions=1
216+
)
217+
backend = xla_bridge.get_backend()
218+
hash_without_callback_ptrs = cache_key.get(
219+
computation,
220+
devices,
221+
compile_options,
222+
backend,
223+
ignore_callbacks=cache_key.IgnoreCallbacks.CUSTOM_PARTITIONING,
224+
)
225+
expected_hash = cache_key.get(
226+
updated_module, devices, compile_options, backend
227+
)
228+
self.assertEqual(expected_hash, hash_without_callback_ptrs)
229+
230+
@jtu.skip_on_devices("cpu")
231+
def test_host_callbacks_ptrs_removed(self):
232+
def _host_callback(x, y):
233+
jax.debug.print("x={x[0]} y={y[0]}", x=x, y=y)
234+
235+
computation = (
236+
jax.jit(_host_callback)
237+
.lower(
238+
jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
239+
jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
240+
)
241+
.compiler_ir()
242+
)
243+
pattern = r'(.*?backend_config\s*=\s*"([^"]*)".*?)'
244+
with computation.context:
245+
updated_module = cache_key._remove_callbacks(
246+
type_cast(ir.Module, computation.operation.clone()),
247+
ignore_callbacks=cache_key.IgnoreCallbacks.ALL,
248+
)
249+
bcs = [
250+
match[1]
251+
for match in re.findall(pattern, str(updated_module), re.DOTALL)
252+
]
253+
for bc in bcs:
254+
self.assertEqual(bc, "REMOVED")
210255

211256
def test_different_device_assignment(self):
212257
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()

0 commit comments

Comments
 (0)