Skip to content

Commit 6dc4dee

Browse files
Merge pull request #182 from ROCm/ci-upstream-sync-55_1
CI: 12/10/24 upstream sync
2 parents f04f164 + 20c5c71 commit 6dc4dee

33 files changed

+949
-741
lines changed

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
1212

1313
## jax 0.4.38
1414

15+
* Deprecations
16+
* a number of APIs in the internal `jax.core` namespace have been deprecated, including
17+
`ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`,
18+
`Literal`, `Primitive`, `raise_to_shaped`, `Token`, `Var`. Most can be replaced by
19+
APIs of the same name in {mod}`jax.extend.core`; see the documentation for
20+
{mod}`jax.extend` for information on the compatibility guarantees of these
21+
semi-public extensions.
22+
1523
## jax 0.4.37 (Dec 9, 2024)
1624

1725
This is a patch release of jax 0.4.36. Only "jax" was released at this version.

docs/contributor_guide.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,3 @@ some of JAX's (extensible) internals.
2525

2626
autodidax
2727
jep/index
28-
jax_internal_api

docs/jax.extend.core.rst

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
``jax.extend.core`` module
2+
==========================
3+
4+
.. automodule:: jax.extend.core
5+
6+
.. autosummary::
7+
:toctree: _autosummary
8+
9+
ClosedJaxpr
10+
Jaxpr
11+
JaxprEqn
12+
Literal
13+
Primitive
14+
Token
15+
Var
16+
array_types
17+
jaxpr_as_fun
18+
primitives

docs/jax.extend.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Modules
1111
.. toctree::
1212
:maxdepth: 1
1313

14+
jax.extend.core
1415
jax.extend.ffi
1516
jax.extend.linear_util
1617
jax.extend.mlir

docs/jax_internal_api.rst

Lines changed: 0 additions & 14 deletions
This file was deleted.

jax/_src/compiler.py

Lines changed: 123 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,33 @@ def compile_or_get_cached(
348348

349349
use_compilation_cache = compilation_cache.is_cache_used(backend)
350350

351+
is_multi_process = (
352+
len({device.process_index for device in devices.flatten()}) > 1
353+
)
354+
min_device_process_id = min(
355+
devices.flatten(), key=lambda device: device.id
356+
).process_index
357+
is_auto_pgle_used = (
358+
config.enable_pgle.value and config.pgle_profiling_runs.value > 0
359+
)
360+
351361
if not use_compilation_cache:
362+
if (
363+
is_multi_process
364+
and is_auto_pgle_used
365+
and distributed.global_state.client is not None
366+
):
367+
compile_options.executable_build_options.fdo_profile = (
368+
_share_fdo_profiles(
369+
computation,
370+
devices,
371+
compile_options,
372+
backend,
373+
distributed.global_state.client,
374+
min_device_process_id,
375+
)
376+
)
377+
352378
return backend_compile(backend, computation, compile_options,
353379
host_callbacks)
354380

@@ -373,61 +399,18 @@ def compile_or_get_cached(
373399
return backend_compile(backend, computation, compile_options,
374400
host_callbacks)
375401

376-
is_multi_process = (
377-
len({device.process_index for device in devices.flatten()}) > 1)
378-
min_device_process_id = (
379-
min(devices.flatten(), key=lambda device: device.id).process_index)
380-
381-
# When PGLE is enabled there might be 3 types of situations:
382-
# 1. PGLE profiled module (the one which was recompiled with FDO profile) is
383-
# in the persistent cache. In this case the module should be returned from
384-
# cache and PGLE should be disabled for this module. Is module is stored in
385-
# the persistent cache under the "pgle_profiled_module_key" which calculated
386-
# with replacing FDO profile with flag which identify that module were PGLE
387-
# profiled.
388-
# 2. PGLE profiled module is not in the persistent cache and the module is
389-
# getting built with an FDO profile. In this case we need to share FDO profile
390-
# with other processes and store the result under the
391-
# "pgle_profiled_module_key" so later in case 1 we will be able to find the
392-
# module.
393-
# 3. PGLE profiled module is not in the persistent cache and the module is
394-
# getting compiled to be PGLEd (FDO profile is empty). In this case we need to
395-
# simply return the non-PGLE profiled module from the persistent cache.
396-
if (config.enable_pgle.value
397-
and config.pgle_profiling_runs.value > 0):
398-
fdo_profile = compile_options.executable_build_options.fdo_profile
399-
compile_options.executable_build_options.fdo_profile = b"pgle profiled"
400-
401-
pgle_profiled_module_key = compilation_cache.get_cache_key(
402+
if is_auto_pgle_used:
403+
cache_key = _resolve_pgle_module_cache_key(
402404
computation,
403405
devices,
404406
compile_options,
405407
backend,
406-
cache_key_type.IgnoreCallbacks.ALL,
408+
pgle_profiler,
409+
is_multi_process,
410+
cache_key,
411+
module_name,
412+
min_device_process_id,
407413
)
408-
compile_options.executable_build_options.fdo_profile = fdo_profile
409-
410-
if _is_executable_in_cache(backend, pgle_profiled_module_key):
411-
# Load PGLE profiled module from the persistent cache.
412-
cache_key = pgle_profiled_module_key
413-
if pgle_profiler is not None:
414-
pgle_profiler.disable()
415-
elif fdo_profile is not None and len(fdo_profile) > 0:
416-
# Store module under PGLE profiled module cache key.
417-
cache_key = pgle_profiled_module_key
418-
if is_multi_process and distributed.global_state.client is not None:
419-
compile_options.executable_build_options.fdo_profile = _share_fdo_profiles(
420-
computation, devices, compile_options, backend,
421-
distributed.global_state.client,
422-
min_device_process_id
423-
)
424-
else:
425-
compile_options.executable_build_options.fdo_profile = fdo_profile
426-
logger.debug(
427-
"Compiling module %s with FDO profile: %s",
428-
module_name,
429-
compile_options.executable_build_options.fdo_profile,
430-
)
431414

432415
cache_retrieval_start = time.monotonic()
433416
retrieved_executable, retrieved_compile_time = _cache_read(
@@ -493,6 +476,75 @@ def compile_or_get_cached(
493476
cache_key,
494477
)
495478

479+
480+
# When PGLE is enabled there might be 3 types of situations:
481+
# 1. PGLE profiled module (the one which was recompiled with FDO profile) is
482+
# in the persistent cache. In this case the module should be returned from
483+
# cache and PGLE should be disabled for this module. The module is stored in
484+
# the persistent cache under the "pgle_profiled_module_key", which is calculated
485+
# by replacing the FDO profile with a flag which identifies that the module was PGLE
486+
# profiled.
487+
# 2. PGLE profiled module is not in the persistent cache and the module is
488+
# getting built with an FDO profile. In this case we need to share FDO profile
489+
# with other processes and store the result under the
490+
# "pgle_profiled_module_key" so later in case 1 we will be able to find the
491+
# module.
492+
# 3. PGLE profiled module is not in the persistent cache and the module is
493+
# getting compiled to be PGLEd (FDO profile is empty). In this case we need to
494+
# simply return the non-PGLE profiled module from the persistent cache.
495+
def _resolve_pgle_module_cache_key(
496+
computation: ir.Module,
497+
devices: np.ndarray,
498+
compile_options: xc.CompileOptions,
499+
backend: xc.Client,
500+
pgle_profiler: profiler.PGLEProfiler | None,
501+
is_multi_process: bool,
502+
cache_key: str,
503+
module_name: str,
504+
min_device_process_id: int,
505+
) -> str:
506+
fdo_profile = compile_options.executable_build_options.fdo_profile
507+
compile_options.executable_build_options.fdo_profile = b"pgle profiled"
508+
509+
pgle_profiled_module_key = compilation_cache.get_cache_key(
510+
computation,
511+
devices,
512+
compile_options,
513+
backend,
514+
cache_key_type.IgnoreCallbacks.ALL,
515+
)
516+
compile_options.executable_build_options.fdo_profile = fdo_profile
517+
518+
result_key = cache_key
519+
if _is_executable_in_cache(backend, pgle_profiled_module_key):
520+
# Load PGLE profiled module from the persistent cache.
521+
result_key = pgle_profiled_module_key
522+
if pgle_profiler is not None:
523+
pgle_profiler.disable()
524+
elif fdo_profile is not None and len(fdo_profile) > 0:
525+
# Store module under PGLE profiled module cache key.
526+
result_key = pgle_profiled_module_key
527+
if is_multi_process and distributed.global_state.client is not None:
528+
compile_options.executable_build_options.fdo_profile = (
529+
_share_fdo_profiles(
530+
computation,
531+
devices,
532+
compile_options,
533+
backend,
534+
distributed.global_state.client,
535+
min_device_process_id,
536+
)
537+
)
538+
else:
539+
compile_options.executable_build_options.fdo_profile = fdo_profile
540+
logger.debug(
541+
"Compiling module %s with FDO profile of length %d",
542+
module_name,
543+
len(compile_options.executable_build_options.fdo_profile),
544+
)
545+
return result_key
546+
547+
496548
# The process that has the lowest device ID should share FDO profile before
497549
# compilation with other processes.
498550
def _share_fdo_profiles(
@@ -510,32 +562,39 @@ def _share_fdo_profiles(
510562
return fdo_profile
511563

512564
compile_options.executable_build_options.fdo_profile = b""
513-
profile_key = (
514-
compilation_cache.get_cache_key(
515-
computation,
516-
devices,
517-
compile_options,
518-
backend,
519-
cache_key_type.IgnoreCallbacks.ALL,
520-
)
521-
+ "_fdo_sync"
522-
)
565+
try:
566+
profile_key = (
567+
compilation_cache.get_cache_key(
568+
computation,
569+
devices,
570+
compile_options,
571+
backend,
572+
cache_key_type.IgnoreCallbacks.ALL,
573+
)
574+
+ "_fdo_sync"
575+
)
576+
except xc._xla.XlaRuntimeError as ex:
577+
logger.error(
578+
"compile_or_get_cached: unable to generate cache key, "
579+
"skipping the fdo profile sharing: %s",
580+
ex,
581+
)
582+
return fdo_profile
583+
523584
if profile_key in _share_fdo_profiles.modules_profiles:
524585
return _share_fdo_profiles.modules_profiles[profile_key]
525586

526587
share_timeout = config.share_binary_between_hosts_timeout_ms.value
527588
if distributed.global_state.process_id == min_process_id:
528589
logger.debug(
529-
"Sharing FDO profile: %s. For module %s. Process %d.",
530-
fdo_profile,
590+
"Module %s. Sharing FDO profile. Process %d.",
531591
module_name,
532592
min_process_id,
533593
)
534594
global_client.key_value_set_bytes(profile_key, fdo_profile)
535595
else:
536596
logger.debug(
537-
"Waiting for FDO profile: %s. For module %s. Should be set by process %d.",
538-
fdo_profile,
597+
"Module %s. Waiting for FDO profile which should be set by process %d.",
539598
module_name,
540599
min_process_id,
541600
)

jax/_src/cudnn/fused_attention_stablehlo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import math
1919

2020
import jax
21-
from jax import core
2221
from jax import dtypes
22+
from jax._src import core
2323
from jax._src import dispatch
2424
from jax._src.custom_partitioning import custom_partitioning
2525
from jax._src.interpreters import batching

jax/_src/cudnn/fusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import functools
1616
import jax
17-
from jax import core as jax_core
17+
from jax._src import core as jax_core
1818
from jax.interpreters import mlir
1919
from jax.interpreters.mlir import hlo
2020
from jax.interpreters.mlir import ir

0 commit comments

Comments
 (0)