Skip to content

Commit 7214a3a

Browse files
[AutoPGLE] Add multi-process test case
PiperOrigin-RevId: 703031689
1 parent 8163e74 commit 7214a3a

File tree

1 file changed

+37
-28
lines changed

1 file changed

+37
-28
lines changed

jax/_src/cache_key.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,38 @@ def _hash_accelerator_config(hash_obj, accelerators: np.ndarray, backend):
238238
_hash_devices(hash_obj, accelerators)
239239
_hash_platform(hash_obj, backend)
240240

241+
# LINT.IfChange(xla_flags)
242+
xla_flags_to_exclude_from_cache_key = [
243+
"--xla_dump_compress_protos",
244+
"--xla_dump_module_metadata",
245+
"--xla_dump_max_hlo_modules",
246+
"--xla_dump_include_timestamp",
247+
"--xla_dump_hlo_pass_re",
248+
"--xla_dump_hlo_module_re",
249+
"--xla_dump_hlo_snapshots",
250+
"--xla_dump_fusion_visualization",
251+
"--xla_dump_hlo_as_url",
252+
"--xla_dump_hlo_as_proto",
253+
"--xla_dump_hlo_as_text",
254+
"--xla_dump_hlo_as_long_text",
255+
"--xla_dump_hlo_as_html",
256+
"--xla_dump_hlo_as_dot",
257+
"--xla_dump_to",
258+
"--xla_force_host_platform_device_count",
259+
"--xla_dump_disable_metadata",
260+
"--xla_dump_hlo_pipeline_re",
261+
"--xla_tpu_sdc_checker_streamz_metric",
262+
"--xla_tpu_sdc_checker_enable_sdc_event_callbacks",
263+
"--xla_tpu_sdc_checker_enable_coresweep_ng_callbacks",
264+
"--xla_tpu_sdc_checker_no_logging_if_callbacks_are_present",
265+
"--xla_gpu_cuda_data_dir",
266+
"--xla_gpu_experimental_autotune_cache_mode",
267+
]
268+
269+
env_override_flags_to_exclude_from_cache_key = {
270+
x.strip("-") for x in xla_flags_to_exclude_from_cache_key
271+
}
272+
# LINT.ThenChange(:debug_options)
241273

242274
def _hash_serialized_compile_options(hash_obj, compile_options_obj,
243275
strip_device_assignment=False):
@@ -284,6 +316,11 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj,
284316
debug_options.xla_gpu_cuda_data_dir = ""
285317
# LINT.ThenChange(:xla_flags)
286318

319+
compile_options_copy.env_option_overrides = [
320+
flag_value
321+
for flag_value in compile_options_copy.env_option_overrides
322+
if flag_value[0] not in env_override_flags_to_exclude_from_cache_key
323+
]
287324
if strip_device_assignment and compile_options_copy.device_assignment:
288325
replica_count = compile_options_copy.device_assignment.replica_count()
289326
computation_count = compile_options_copy.device_assignment.computation_count()
@@ -301,34 +338,6 @@ def _hash_platform(hash_obj, backend):
301338

302339

303340
def _hash_xla_flags(hash_obj, extra_flag_prefixes: list[str]):
304-
# LINT.IfChange(xla_flags)
305-
xla_flags_to_exclude_from_cache_key = [
306-
"--xla_dump_compress_protos",
307-
"--xla_dump_module_metadata",
308-
"--xla_dump_max_hlo_modules",
309-
"--xla_dump_include_timestamp",
310-
"--xla_dump_hlo_pass_re",
311-
"--xla_dump_hlo_module_re",
312-
"--xla_dump_hlo_snapshots",
313-
"--xla_dump_fusion_visualization",
314-
"--xla_dump_hlo_as_url",
315-
"--xla_dump_hlo_as_proto",
316-
"--xla_dump_hlo_as_text",
317-
"--xla_dump_hlo_as_long_text",
318-
"--xla_dump_hlo_as_html",
319-
"--xla_dump_hlo_as_dot",
320-
"--xla_dump_to",
321-
"--xla_force_host_platform_device_count",
322-
"--xla_dump_disable_metadata",
323-
"--xla_dump_hlo_pipeline_re",
324-
"--xla_tpu_sdc_checker_streamz_metric",
325-
"--xla_tpu_sdc_checker_enable_sdc_event_callbacks",
326-
"--xla_tpu_sdc_checker_enable_coresweep_ng_callbacks",
327-
"--xla_tpu_sdc_checker_no_logging_if_callbacks_are_present",
328-
"--xla_gpu_cuda_data_dir",
329-
]
330-
# LINT.ThenChange(:debug_options)
331-
332341
xla_flags = []
333342

334343
xla_flags_env_var = os.getenv("XLA_FLAGS")

0 commit comments

Comments
 (0)