File tree Expand file tree Collapse file tree 1 file changed +27
-8
lines changed Expand file tree Collapse file tree 1 file changed +27
-8
lines changed Original file line number Diff line number Diff line change @@ -348,7 +348,33 @@ def compile_or_get_cached(
348348
349349 use_compilation_cache = compilation_cache .is_cache_used (backend )
350350
351+ is_multi_process = (
352+ len ({device .process_index for device in devices .flatten ()}) > 1
353+ )
354+ min_device_process_id = min (
355+ devices .flatten (), key = lambda device : device .id
356+ ).process_index
357+ is_auto_pgle_used = (
358+ config .enable_pgle .value and config .pgle_profiling_runs .value > 0
359+ )
360+
351361 if not use_compilation_cache :
362+ if (
363+ is_multi_process
364+ and is_auto_pgle_used
365+ and distributed .global_state .client is not None
366+ ):
367+ compile_options .executable_build_options .fdo_profile = (
368+ _share_fdo_profiles (
369+ computation ,
370+ devices ,
371+ compile_options ,
372+ backend ,
373+ distributed .global_state .client ,
374+ min_device_process_id ,
375+ )
376+ )
377+
352378 return backend_compile (backend , computation , compile_options ,
353379 host_callbacks )
354380
@@ -373,14 +399,7 @@ def compile_or_get_cached(
373399 return backend_compile (backend , computation , compile_options ,
374400 host_callbacks )
375401
376- is_multi_process = (
377- len ({device .process_index for device in devices .flatten ()}) > 1
378- )
379- min_device_process_id = min (
380- devices .flatten (), key = lambda device : device .id
381- ).process_index
382-
383- if config .enable_pgle .value and config .pgle_profiling_runs .value > 0 :
402+ if is_auto_pgle_used :
384403 cache_key = _resolve_pgle_module_cache_key (
385404 computation ,
386405 devices ,
You can’t perform that action at this time.
0 commit comments