@@ -348,7 +348,33 @@ def compile_or_get_cached(
348348
349349 use_compilation_cache = compilation_cache .is_cache_used (backend )
350350
351+ is_multi_process = (
352+ len ({device .process_index for device in devices .flatten ()}) > 1
353+ )
354+ min_device_process_id = min (
355+ devices .flatten (), key = lambda device : device .id
356+ ).process_index
357+ is_auto_pgle_used = (
358+ config .enable_pgle .value and config .pgle_profiling_runs .value > 0
359+ )
360+
351361 if not use_compilation_cache :
362+ if (
363+ is_multi_process
364+ and is_auto_pgle_used
365+ and distributed .global_state .client is not None
366+ ):
367+ compile_options .executable_build_options .fdo_profile = (
368+ _share_fdo_profiles (
369+ computation ,
370+ devices ,
371+ compile_options ,
372+ backend ,
373+ distributed .global_state .client ,
374+ min_device_process_id ,
375+ )
376+ )
377+
352378 return backend_compile (backend , computation , compile_options ,
353379 host_callbacks )
354380
@@ -373,61 +399,18 @@ def compile_or_get_cached(
373399 return backend_compile (backend , computation , compile_options ,
374400 host_callbacks )
375401
376- is_multi_process = (
377- len ({device .process_index for device in devices .flatten ()}) > 1 )
378- min_device_process_id = (
379- min (devices .flatten (), key = lambda device : device .id ).process_index )
380-
381- # When PGLE is enabled there might be 3 types of situations:
382- # 1. PGLE profiled module (the one which was recompiled with FDO profile) is
383- # in the persistent cache. In this case the module should be returned from
384- # cache and PGLE should be disabled for this module. Is module is stored in
385- # the persistent cache under the "pgle_profiled_module_key" which calculated
386- # with replacing FDO profile with flag which identify that module were PGLE
387- # profiled.
388- # 2. PGLE profiled module is not in the persistent cache and the module is
389- # getting built with an FDO profile. In this case we need to share FDO profile
390- # with other processes and store the result under the
391- # "pgle_profiled_module_key" so later in case 1 we will be able to find the
392- # module.
393- # 3. PGLE profiled module is not in the persistent cache and the module is
394- # getting compiled to be PGLEd (FDO profile is empty). In this case we need to
395- # simply return the non-PGLE profiled module from the persistent cache.
396- if (config .enable_pgle .value
397- and config .pgle_profiling_runs .value > 0 ):
398- fdo_profile = compile_options .executable_build_options .fdo_profile
399- compile_options .executable_build_options .fdo_profile = b"pgle profiled"
400-
401- pgle_profiled_module_key = compilation_cache .get_cache_key (
402+ if is_auto_pgle_used :
403+ cache_key = _resolve_pgle_module_cache_key (
402404 computation ,
403405 devices ,
404406 compile_options ,
405407 backend ,
406- cache_key_type .IgnoreCallbacks .ALL ,
408+ pgle_profiler ,
409+ is_multi_process ,
410+ cache_key ,
411+ module_name ,
412+ min_device_process_id ,
407413 )
408- compile_options .executable_build_options .fdo_profile = fdo_profile
409-
410- if _is_executable_in_cache (backend , pgle_profiled_module_key ):
411- # Load PGLE profiled module from the persistent cache.
412- cache_key = pgle_profiled_module_key
413- if pgle_profiler is not None :
414- pgle_profiler .disable ()
415- elif fdo_profile is not None and len (fdo_profile ) > 0 :
416- # Store module under PGLE profiled module cache key.
417- cache_key = pgle_profiled_module_key
418- if is_multi_process and distributed .global_state .client is not None :
419- compile_options .executable_build_options .fdo_profile = _share_fdo_profiles (
420- computation , devices , compile_options , backend ,
421- distributed .global_state .client ,
422- min_device_process_id
423- )
424- else :
425- compile_options .executable_build_options .fdo_profile = fdo_profile
426- logger .debug (
427- "Compiling module %s with FDO profile: %s" ,
428- module_name ,
429- compile_options .executable_build_options .fdo_profile ,
430- )
431414
432415 cache_retrieval_start = time .monotonic ()
433416 retrieved_executable , retrieved_compile_time = _cache_read (
@@ -493,6 +476,75 @@ def compile_or_get_cached(
493476 cache_key ,
494477 )
495478
479+
480+ # When PGLE is enabled there might be 3 types of situations:
481+ # 1. PGLE profiled module (the one which was recompiled with FDO profile) is
482+ # in the persistent cache. In this case the module should be returned from
483+ # cache and PGLE should be disabled for this module. Is module is stored in
484+ # the persistent cache under the "pgle_profiled_module_key" which calculated
485+ # with replacing FDO profile with flag which identify that module were PGLE
486+ # profiled.
487+ # 2. PGLE profiled module is not in the persistent cache and the module is
488+ # getting built with an FDO profile. In this case we need to share FDO profile
489+ # with other processes and store the result under the
490+ # "pgle_profiled_module_key" so later in case 1 we will be able to find the
491+ # module.
492+ # 3. PGLE profiled module is not in the persistent cache and the module is
493+ # getting compiled to be PGLEd (FDO profile is empty). In this case we need to
494+ # simply return the non-PGLE profiled module from the persistent cache.
495+ def _resolve_pgle_module_cache_key (
496+ computation : ir .Module ,
497+ devices : np .ndarray ,
498+ compile_options : xc .CompileOptions ,
499+ backend : xc .Client ,
500+ pgle_profiler : profiler .PGLEProfiler | None ,
501+ is_multi_process : bool ,
502+ cache_key : str ,
503+ module_name : str ,
504+ min_device_process_id : int ,
505+ ) -> str :
506+ fdo_profile = compile_options .executable_build_options .fdo_profile
507+ compile_options .executable_build_options .fdo_profile = b"pgle profiled"
508+
509+ pgle_profiled_module_key = compilation_cache .get_cache_key (
510+ computation ,
511+ devices ,
512+ compile_options ,
513+ backend ,
514+ cache_key_type .IgnoreCallbacks .ALL ,
515+ )
516+ compile_options .executable_build_options .fdo_profile = fdo_profile
517+
518+ result_key = cache_key
519+ if _is_executable_in_cache (backend , pgle_profiled_module_key ):
520+ # Load PGLE profiled module from the persistent cache.
521+ result_key = pgle_profiled_module_key
522+ if pgle_profiler is not None :
523+ pgle_profiler .disable ()
524+ elif fdo_profile is not None and len (fdo_profile ) > 0 :
525+ # Store module under PGLE profiled module cache key.
526+ result_key = pgle_profiled_module_key
527+ if is_multi_process and distributed .global_state .client is not None :
528+ compile_options .executable_build_options .fdo_profile = (
529+ _share_fdo_profiles (
530+ computation ,
531+ devices ,
532+ compile_options ,
533+ backend ,
534+ distributed .global_state .client ,
535+ min_device_process_id ,
536+ )
537+ )
538+ else :
539+ compile_options .executable_build_options .fdo_profile = fdo_profile
540+ logger .debug (
541+ "Compiling module %s with FDO profile of length %d" ,
542+ module_name ,
543+ len (compile_options .executable_build_options .fdo_profile ),
544+ )
545+ return result_key
546+
547+
496548# The process that has the lowest device ID should share FDO profile before
497549# compilation with other processes.
498550def _share_fdo_profiles (
@@ -510,32 +562,39 @@ def _share_fdo_profiles(
510562 return fdo_profile
511563
512564 compile_options .executable_build_options .fdo_profile = b""
513- profile_key = (
514- compilation_cache .get_cache_key (
515- computation ,
516- devices ,
517- compile_options ,
518- backend ,
519- cache_key_type .IgnoreCallbacks .ALL ,
520- )
521- + "_fdo_sync"
522- )
565+ try :
566+ profile_key = (
567+ compilation_cache .get_cache_key (
568+ computation ,
569+ devices ,
570+ compile_options ,
571+ backend ,
572+ cache_key_type .IgnoreCallbacks .ALL ,
573+ )
574+ + "_fdo_sync"
575+ )
576+ except xc ._xla .XlaRuntimeError as ex :
577+ logger .error (
578+ "compile_or_get_cached: unable to generate cache key, "
579+ "skipping the fdo profile sharing: %s" ,
580+ ex ,
581+ )
582+ return fdo_profile
583+
523584 if profile_key in _share_fdo_profiles .modules_profiles :
524585 return _share_fdo_profiles .modules_profiles [profile_key ]
525586
526587 share_timeout = config .share_binary_between_hosts_timeout_ms .value
527588 if distributed .global_state .process_id == min_process_id :
528589 logger .debug (
529- "Sharing FDO profile: %s. For module %s. Process %d." ,
530- fdo_profile ,
590+ "Module %s. Sharing FDO profile. Process %d." ,
531591 module_name ,
532592 min_process_id ,
533593 )
534594 global_client .key_value_set_bytes (profile_key , fdo_profile )
535595 else :
536596 logger .debug (
537- "Waiting for FDO profile: %s. For module %s. Should be set by process %d." ,
538- fdo_profile ,
597+ "Module %s. Waiting for FDO profile which should be set by process %d." ,
539598 module_name ,
540599 min_process_id ,
541600 )
0 commit comments