Skip to content

Commit 8813973

Browse files
[AutoPGLE] Cleanup compiler code.
PiperOrigin-RevId: 704741308
1 parent 263d4d1 commit 8813973

File tree

1 file changed

+103
-63
lines changed

1 file changed

+103
-63
lines changed

jax/_src/compiler.py

Lines changed: 103 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -374,60 +374,24 @@ def compile_or_get_cached(
374374
host_callbacks)
375375

376376
is_multi_process = (
377-
len({device.process_index for device in devices.flatten()}) > 1)
378-
min_device_process_id = (
379-
min(devices.flatten(), key=lambda device: device.id).process_index)
380-
381-
# When PGLE is enabled there might be 3 types of situations:
382-
# 1. PGLE profiled module (the one which was recompiled with FDO profile) is
383-
# in the persistent cache. In this case the module should be returned from
384-
# cache and PGLE should be disabled for this module. The module is stored in
385-
# the persistent cache under the "pgle_profiled_module_key", which is calculated
386-
# by replacing the FDO profile with a flag identifying that the module was PGLE
387-
# profiled.
388-
# 2. PGLE profiled module is not in the persistent cache and the module is
389-
# getting built with an FDO profile. In this case we need to share FDO profile
390-
# with other processes and store the result under the
391-
# "pgle_profiled_module_key" so later in case 1 we will be able to find the
392-
# module.
393-
# 3. PGLE profiled module is not in the persistent cache and the module is
394-
# getting compiled to be PGLEd (FDO profile is empty). In this case we need to
395-
# simply return the non-PGLE profiled module from the persistent cache.
396-
if (config.enable_pgle.value
397-
and config.pgle_profiling_runs.value > 0):
398-
fdo_profile = compile_options.executable_build_options.fdo_profile
399-
compile_options.executable_build_options.fdo_profile = b"pgle profiled"
400-
401-
pgle_profiled_module_key = compilation_cache.get_cache_key(
377+
len({device.process_index for device in devices.flatten()}) > 1
378+
)
379+
min_device_process_id = min(
380+
devices.flatten(), key=lambda device: device.id
381+
).process_index
382+
383+
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
384+
cache_key = _resolve_pgle_module_cache_key(
402385
computation,
403386
devices,
404387
compile_options,
405388
backend,
406-
cache_key_type.IgnoreCallbacks.ALL,
389+
pgle_profiler,
390+
is_multi_process,
391+
cache_key,
392+
module_name,
393+
min_device_process_id,
407394
)
408-
compile_options.executable_build_options.fdo_profile = fdo_profile
409-
410-
if _is_executable_in_cache(backend, pgle_profiled_module_key):
411-
# Load PGLE profiled module from the persistent cache.
412-
cache_key = pgle_profiled_module_key
413-
if pgle_profiler is not None:
414-
pgle_profiler.disable()
415-
elif fdo_profile is not None and len(fdo_profile) > 0:
416-
# Store module under PGLE profiled module cache key.
417-
cache_key = pgle_profiled_module_key
418-
if is_multi_process and distributed.global_state.client is not None:
419-
compile_options.executable_build_options.fdo_profile = _share_fdo_profiles(
420-
computation, devices, compile_options, backend,
421-
distributed.global_state.client,
422-
min_device_process_id
423-
)
424-
else:
425-
compile_options.executable_build_options.fdo_profile = fdo_profile
426-
logger.debug(
427-
"Compiling module %s with FDO profile: %s",
428-
module_name,
429-
compile_options.executable_build_options.fdo_profile,
430-
)
431395

432396
cache_retrieval_start = time.monotonic()
433397
retrieved_executable, retrieved_compile_time = _cache_read(
@@ -493,6 +457,75 @@ def compile_or_get_cached(
493457
cache_key,
494458
)
495459

460+
461+
# When PGLE is enabled there might be 3 types of situations:
462+
# 1. PGLE profiled module (the one which was recompiled with FDO profile) is
463+
# in the persistent cache. In this case the module should be returned from
464+
# cache and PGLE should be disabled for this module. The module is stored in
465+
# the persistent cache under the "pgle_profiled_module_key", which is calculated
466+
# by replacing the FDO profile with a flag identifying that the module was PGLE
467+
# profiled.
468+
# 2. PGLE profiled module is not in the persistent cache and the module is
469+
# getting built with an FDO profile. In this case we need to share FDO profile
470+
# with other processes and store the result under the
471+
# "pgle_profiled_module_key" so later in case 1 we will be able to find the
472+
# module.
473+
# 3. PGLE profiled module is not in the persistent cache and the module is
474+
# getting compiled to be PGLEd (FDO profile is empty). In this case we need to
475+
# simply return the non-PGLE profiled module from the persistent cache.
476+
def _resolve_pgle_module_cache_key(
    computation: ir.Module,
    devices: np.ndarray,
    compile_options: xc.CompileOptions,
    backend: xc.Client,
    pgle_profiler: profiler.PGLEProfiler | None,
    is_multi_process: bool,
    cache_key: str,
    module_name: str,
    min_device_process_id: int,
) -> str:
  """Chooses the persistent-cache key for a module when PGLE is enabled.

  A "PGLE profiled" key is computed by temporarily replacing the FDO profile
  in `compile_options` with a fixed marker, then one of three things happens:

  1. An executable already exists in the cache under the PGLE profiled key:
     that key is returned and `pgle_profiler` (if given) is disabled.
  2. No such executable exists but a non-empty FDO profile is set: the PGLE
     profiled key is returned so the result is stored under it. In a
     multi-process run with a distributed client, the FDO profile in
     `compile_options` is replaced by the one from `_share_fdo_profiles`.
  3. Otherwise (no FDO profile yet): `cache_key` is returned unchanged.

  Args:
    computation: Module being compiled; forwarded to the cache-key computation.
    devices: Devices the module is compiled for.
    compile_options: Compile options. Its
      `executable_build_options.fdo_profile` is modified temporarily and, in
      case 2, may be overwritten with the shared profile.
    backend: Client used for the cache-key computation and cache lookup.
    pgle_profiler: Optional profiler, disabled in case 1.
    is_multi_process: Whether the devices span more than one process.
    cache_key: The regular (non-PGLE) cache key, returned in case 3.
    module_name: Name used in the debug log message.
    min_device_process_id: Passed through to `_share_fdo_profiles`.

  Returns:
    The cache key under which the executable should be looked up or stored.
  """
  fdo_profile = compile_options.executable_build_options.fdo_profile
  compile_options.executable_build_options.fdo_profile = b"pgle profiled"
  try:
    pgle_profiled_module_key = compilation_cache.get_cache_key(
        computation,
        devices,
        compile_options,
        backend,
        cache_key_type.IgnoreCallbacks.ALL,
    )
  finally:
    # Always put the caller's profile back, even if key generation raises, so
    # the marker value never leaks into a real compilation.
    compile_options.executable_build_options.fdo_profile = fdo_profile

  result_key = cache_key
  if _is_executable_in_cache(backend, pgle_profiled_module_key):
    # Load PGLE profiled module from the persistent cache.
    result_key = pgle_profiled_module_key
    if pgle_profiler is not None:
      pgle_profiler.disable()
  elif fdo_profile:
    # Store module under PGLE profiled module cache key.
    result_key = pgle_profiled_module_key
    if is_multi_process and distributed.global_state.client is not None:
      compile_options.executable_build_options.fdo_profile = (
          _share_fdo_profiles(
              computation,
              devices,
              compile_options,
              backend,
              distributed.global_state.client,
              min_device_process_id,
          )
      )
    else:
      compile_options.executable_build_options.fdo_profile = fdo_profile
    logger.debug(
        "Compiling module %s with FDO profile of length %d",
        module_name,
        len(compile_options.executable_build_options.fdo_profile),
    )
  return result_key
527+
528+
496529
# The process that has the lowest device ID should share FDO profile before
497530
# compilation with other processes.
498531
def _share_fdo_profiles(
@@ -510,32 +543,39 @@ def _share_fdo_profiles(
510543
return fdo_profile
511544

512545
compile_options.executable_build_options.fdo_profile = b""
513-
profile_key = (
514-
compilation_cache.get_cache_key(
515-
computation,
516-
devices,
517-
compile_options,
518-
backend,
519-
cache_key_type.IgnoreCallbacks.ALL,
520-
)
521-
+ "_fdo_sync"
522-
)
546+
try:
547+
profile_key = (
548+
compilation_cache.get_cache_key(
549+
computation,
550+
devices,
551+
compile_options,
552+
backend,
553+
cache_key_type.IgnoreCallbacks.ALL,
554+
)
555+
+ "_fdo_sync"
556+
)
557+
except xc._xla.XlaRuntimeError as ex:
558+
logger.error(
559+
"compile_or_get_cached: unable to generate cache key, "
560+
"skipping the fdo profile sharing: %s",
561+
ex,
562+
)
563+
return fdo_profile
564+
523565
if profile_key in _share_fdo_profiles.modules_profiles:
524566
return _share_fdo_profiles.modules_profiles[profile_key]
525567

526568
share_timeout = config.share_binary_between_hosts_timeout_ms.value
527569
if distributed.global_state.process_id == min_process_id:
528570
logger.debug(
529-
"Sharing FDO profile: %s. For module %s. Process %d.",
530-
fdo_profile,
571+
"Module %s. Sharing FDO profile. Process %d.",
531572
module_name,
532573
min_process_id,
533574
)
534575
global_client.key_value_set_bytes(profile_key, fdo_profile)
535576
else:
536577
logger.debug(
537-
"Waiting for FDO profile: %s. For module %s. Should be set by process %d.",
538-
fdo_profile,
578+
"Module %s. Waiting for FDO profile which should be set by process %d.",
539579
module_name,
540580
min_process_id,
541581
)

0 commit comments

Comments
 (0)