Skip to content

Commit e6dfe8f

Browse files
[AutoPGLE] Share FDO profile even when compilation cache disabled.
PiperOrigin-RevId: 704757991
1 parent 6dbafed commit e6dfe8f

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

jax/_src/compiler.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,33 @@ def compile_or_get_cached(
348348

349349
use_compilation_cache = compilation_cache.is_cache_used(backend)
350350

351+
is_multi_process = (
352+
len({device.process_index for device in devices.flatten()}) > 1
353+
)
354+
min_device_process_id = min(
355+
devices.flatten(), key=lambda device: device.id
356+
).process_index
357+
is_auto_pgle_used = (
358+
config.enable_pgle.value and config.pgle_profiling_runs.value > 0
359+
)
360+
351361
if not use_compilation_cache:
362+
if (
363+
is_multi_process
364+
and is_auto_pgle_used
365+
and distributed.global_state.client is not None
366+
):
367+
compile_options.executable_build_options.fdo_profile = (
368+
_share_fdo_profiles(
369+
computation,
370+
devices,
371+
compile_options,
372+
backend,
373+
distributed.global_state.client,
374+
min_device_process_id,
375+
)
376+
)
377+
352378
return backend_compile(backend, computation, compile_options,
353379
host_callbacks)
354380

@@ -373,14 +399,7 @@ def compile_or_get_cached(
373399
return backend_compile(backend, computation, compile_options,
374400
host_callbacks)
375401

376-
is_multi_process = (
377-
len({device.process_index for device in devices.flatten()}) > 1
378-
)
379-
min_device_process_id = min(
380-
devices.flatten(), key=lambda device: device.id
381-
).process_index
382-
383-
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
402+
if is_auto_pgle_used:
384403
cache_key = _resolve_pgle_module_cache_key(
385404
computation,
386405
devices,

0 commit comments

Comments
 (0)