@@ -1608,15 +1608,22 @@ def _resolve_and_lower(
16081608 lowering_parameters = lowering_parameters ,
16091609 pgle_profiler = pgle_profiler )
16101610
1611+ _pgle_profiler_dict = weakref .WeakKeyDictionary () # type: ignore
1612+
16111613def _pjit_call_impl_python (
16121614 * args , jaxpr , in_shardings , out_shardings , in_layouts , out_layouts ,
16131615 resource_env , donated_invars , name , keep_unused , inline ,
16141616 compiler_options_kvs ):
16151617 pgle_compile_options , pgle_profiler = {}, None
16161618 if config .enable_pgle .value and config .pgle_profiling_runs .value > 0 :
1617- pgle_profiler = profiler .PGLEProfiler (
1618- config .pgle_profiling_runs .value ,
1619- config .pgle_aggregation_percentile .value )
1619+ compilation_target_key = jaxpr
1620+ pgle_profiler = _pgle_profiler_dict .get (compilation_target_key )
1621+ if pgle_profiler is None :
1622+ pgle_profiler = profiler .PGLEProfiler (
1623+ config .pgle_profiling_runs .value ,
1624+ config .pgle_aggregation_percentile .value )
1625+ _pgle_profiler_dict [compilation_target_key ] = pgle_profiler
1626+
16201627 # The method below will return FDO profile when module was profiled
16211628 # config.jax_pgle_profiling_runs amount of times, otherwise the result will
16221629 # be None.
0 commit comments