Skip to content

Commit 64fcb9d

Browse files
dougalmGoogle-ML-Automation
authored andcommitted
Fix pgle profiling, broken in previous change.
PiperOrigin-RevId: 695762690
1 parent b185e64 commit 64fcb9d

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

jax/_src/pjit.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1608,15 +1608,22 @@ def _resolve_and_lower(
16081608
lowering_parameters=lowering_parameters,
16091609
pgle_profiler=pgle_profiler)
16101610

1611+
_pgle_profiler_dict = weakref.WeakKeyDictionary() # type: ignore
1612+
16111613
def _pjit_call_impl_python(
16121614
*args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
16131615
resource_env, donated_invars, name, keep_unused, inline,
16141616
compiler_options_kvs):
16151617
pgle_compile_options, pgle_profiler = {}, None
16161618
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
1617-
pgle_profiler = profiler.PGLEProfiler(
1618-
config.pgle_profiling_runs.value,
1619-
config.pgle_aggregation_percentile.value)
1619+
compilation_target_key = jaxpr
1620+
pgle_profiler = _pgle_profiler_dict.get(compilation_target_key)
1621+
if pgle_profiler is None:
1622+
pgle_profiler = profiler.PGLEProfiler(
1623+
config.pgle_profiling_runs.value,
1624+
config.pgle_aggregation_percentile.value)
1625+
_pgle_profiler_dict[compilation_target_key] = pgle_profiler
1626+
16201627
# The method below will return FDO profile when module was profiled
16211628
# config.jax_pgle_profiling_runs amount of times, otherwise the result will
16221629
# be None.

0 commit comments

Comments
 (0)