
Commit a123d4e

Remove autotune sharing.
xla_gpu_shard_autotuning can now be used instead, and it is enabled by default.

PiperOrigin-RevId: 705792463
1 parent d0f63da commit a123d4e
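For anyone who previously relied on jax_share_autotune_config_between_hosts, the replacement is the XLA flag named in the message above. Below is a minimal sketch of opting out, assuming the usual XLA_FLAGS environment-variable mechanism for XLA GPU debug options; the flag name is taken from the commit message and is not a JAX config option.

import os

# Sketch only: XLA GPU debug options are passed via XLA_FLAGS, which must be
# set before the XLA backend is initialized. xla_gpu_shard_autotuning defaults
# to enabled, so setting it is only needed to opt out (or to be explicit).
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false"
)

import jax  # import JAX after setting XLA_FLAGS so the flag takes effect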

File tree

2 files changed: +0 additions, −139 deletions


jax/_src/compiler.py

Lines changed: 0 additions & 125 deletions
@@ -18,8 +18,6 @@
 
 from collections.abc import Sequence
 import logging
-import os
-import tempfile
 import time
 from typing import Any, Callable
 import warnings
@@ -449,22 +447,6 @@ def compile_or_get_cached(
         cache_key,
         min_device_process_id
     )
-  elif (
-      config.share_autotune_config_between_hosts.value
-      and is_multi_process
-      and distributed.global_state.client is not None
-  ):
-    log_persistent_cache_miss(module_name, cache_key)
-    return _compile_and_write_autotune_config(
-        backend,
-        computation,
-        compile_options,
-        host_callbacks,
-        distributed.global_state.client,
-        module_name,
-        cache_key,
-        min_device_process_id
-    )
   else:
     log_persistent_cache_miss(module_name, cache_key)
     return _compile_and_write_cache(
@@ -608,113 +590,6 @@ def _share_fdo_profiles(
 
 _share_fdo_profiles.modules_profiles = {}
 
-
-# The process with the first_process_id should compile the module and write an
-# autotune config to the K-V storage.
-def _compile_and_write_autotune_config(
-    backend: xc.Client,
-    computation: ir.Module,
-    compile_options: xc.CompileOptions,
-    host_callbacks: Sequence[Any],
-    global_client: lib.xla_extension.DistributedRuntimeClient,
-    module_name: str,
-    cache_key: str,
-    first_process_id: int
-) -> xc.LoadedExecutable:
-  share_timeout = config.share_binary_between_hosts_timeout_ms.value
-  debug_options = compile_options.executable_build_options.debug_options
-
-  if _compile_and_write_autotune_config.autotune_configs_dir is None:
-    _compile_and_write_autotune_config.autotune_configs_dir = tempfile.mkdtemp()
-
-  autotune_tmp_file = os.path.join(
-      _compile_and_write_autotune_config.autotune_configs_dir, cache_key
-  )
-
-  if os.path.exists(autotune_tmp_file):
-    logger.debug(
-        "Compiling module: %s. Use existing autotune config file: %s",
-        module_name,
-        autotune_tmp_file,
-    )
-    debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
-    return _compile_and_write_cache(
-        backend,
-        computation,
-        compile_options,
-        host_callbacks,
-        module_name,
-        cache_key,
-    )
-
-  if distributed.global_state.process_id == first_process_id:
-    debug_options.xla_gpu_dump_autotune_results_to = autotune_tmp_file
-    logger.debug("Process %d compiling and dumping autotune for module: %s",
-                 first_process_id, module_name)
-    executable = _compile_and_write_cache(
-        backend,
-        computation,
-        compile_options,
-        host_callbacks,
-        module_name,
-        cache_key,
-    )
-
-    logger.debug(
-        "Writing autotune config for module %s to %s",
-        module_name,
-        autotune_tmp_file,
-    )
-    with open(autotune_tmp_file, "rb") as f:
-      autotune_config = f.read()
-
-    autotune_config = compilation_cache.compress_executable(autotune_config)
-    global_client.key_value_set_bytes(cache_key, autotune_config)
-    logger.debug(
-        "Autotune config for module %s with size %d shared by cache_key %s",
-        module_name,
-        len(autotune_config),
-        cache_key,
-    )
-  else:
-    logger.debug(
-        "Compiling module %s, waiting for config to be shared by cache_key %s"
-        "from process %d",
-        module_name,
-        cache_key,
-        first_process_id
-    )
-    autotune_config = global_client.blocking_key_value_get_bytes(
-        cache_key, share_timeout
-    )
-
-    logger.debug(
-        "Received autotune config for module %s of size %d",
-        module_name,
-        len(autotune_config),
-    )
-    autotune_config = compilation_cache.decompress_executable(autotune_config)
-    with open(autotune_tmp_file, "wb") as f:
-      f.write(autotune_config)
-
-    logger.debug(
-        "Compiling module %s, using autotune config from %s",
-        module_name,
-        autotune_tmp_file,
-    )
-    debug_options.xla_gpu_load_autotune_results_from = autotune_tmp_file
-    executable = _compile_and_write_cache(
-        backend,
-        computation,
-        compile_options,
-        host_callbacks,
-        module_name,
-        cache_key,
-    )
-  return executable
-
-_compile_and_write_autotune_config.autotune_configs_dir = None
-
 # The process with the first_process_id should compile the module and write it
 # to the K-V storage.
 def _compile_and_share_module(
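
For context, the deleted _compile_and_write_autotune_config implemented a leader/follower exchange over the distributed key-value store: the process with first_process_id compiled with xla_gpu_dump_autotune_results_to set and published the resulting file under the cache key, while every other process blocked on that key and compiled with xla_gpu_load_autotune_results_from pointing at the received bytes. A condensed, hedged sketch of that pattern follows; client, compile_fn, and the other parameters are stand-ins for the values the deleted helper received or created, and compression of the payload is elided.

def share_autotune_config(client, process_id, leader_id, cache_key,
                          debug_options, tmp_file, timeout_ms, compile_fn):
  # Sketch of the removed pattern, not current JAX API: only the key-value
  # calls (key_value_set_bytes / blocking_key_value_get_bytes) and the two
  # XLA debug options appear in the diff above; everything else is assumed.
  if process_id == leader_id:
    # Leader: compile while dumping autotune results, then publish them.
    debug_options.xla_gpu_dump_autotune_results_to = tmp_file
    executable = compile_fn()
    with open(tmp_file, "rb") as f:
      client.key_value_set_bytes(cache_key, f.read())
  else:
    # Followers: block until the leader publishes, then compile with the
    # shared results so every process produces the same tuned module.
    config_bytes = client.blocking_key_value_get_bytes(cache_key, timeout_ms)
    with open(tmp_file, "wb") as f:
      f.write(config_bytes)
    debug_options.xla_gpu_load_autotune_results_from = tmp_file
    executable = compile_fn()
  return executable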

jax/_src/config.py

Lines changed: 0 additions & 14 deletions
@@ -1169,20 +1169,6 @@ def _update_jax_memories_thread_local(val):
     ),
 )
 
-share_autotune_config_between_hosts = bool_state(
-    name='jax_share_autotune_config_between_hosts',
-    default=False,
-    help=(
-        'If set to True, the coordinator process will share autotune configs '
-        'other participants. This will increase overall compilation time, but '
-        'will lead to equal compiled modules in each process. '
-        'If both jax_share_binary_between_hosts and '
-        'jax_share_autotune_config_between_hosts are set, compiled HLO will be '
-        "shared when it's possible and autotune config sharing will be used "
-        'as a fallback.'
-    ),
-)
-
 share_binary_between_hosts = bool_state(
     name='jax_share_binary_between_hosts',
     default=False,
