@@ -374,60 +374,24 @@ def compile_or_get_cached(
374374 host_callbacks )
375375
376376 is_multi_process = (
377- len ({device .process_index for device in devices .flatten ()}) > 1 )
378- min_device_process_id = (
379- min (devices .flatten (), key = lambda device : device .id ).process_index )
380-
381- # When PGLE is enabled there might be 3 types of situations:
382- # 1. PGLE profiled module (the one which was recompiled with FDO profile) is
383- # in the persistent cache. In this case the module should be returned from
384- # cache and PGLE should be disabled for this module. Is module is stored in
385- # the persistent cache under the "pgle_profiled_module_key" which calculated
386- # with replacing FDO profile with flag which identify that module were PGLE
387- # profiled.
388- # 2. PGLE profiled module is not in the persistent cache and the module is
389- # getting built with an FDO profile. In this case we need to share FDO profile
390- # with other processes and store the result under the
391- # "pgle_profiled_module_key" so later in case 1 we will be able to find the
392- # module.
393- # 3. PGLE profiled module is not in the persistent cache and the module is
394- # getting compiled to be PGLEd (FDO profile is empty). In this case we need to
395- # simply return the non-PGLE profiled module from the persistent cache.
396- if (config .enable_pgle .value
397- and config .pgle_profiling_runs .value > 0 ):
398- fdo_profile = compile_options .executable_build_options .fdo_profile
399- compile_options .executable_build_options .fdo_profile = b"pgle profiled"
400-
401- pgle_profiled_module_key = compilation_cache .get_cache_key (
377+ len ({device .process_index for device in devices .flatten ()}) > 1
378+ )
379+ min_device_process_id = min (
380+ devices .flatten (), key = lambda device : device .id
381+ ).process_index
382+
383+ if config .enable_pgle .value and config .pgle_profiling_runs .value > 0 :
384+ cache_key = _resolve_pgle_module_cache_key (
402385 computation ,
403386 devices ,
404387 compile_options ,
405388 backend ,
406- cache_key_type .IgnoreCallbacks .ALL ,
389+ pgle_profiler ,
390+ is_multi_process ,
391+ cache_key ,
392+ module_name ,
393+ min_device_process_id ,
407394 )
408- compile_options .executable_build_options .fdo_profile = fdo_profile
409-
410- if _is_executable_in_cache (backend , pgle_profiled_module_key ):
411- # Load PGLE profiled module from the persistent cache.
412- cache_key = pgle_profiled_module_key
413- if pgle_profiler is not None :
414- pgle_profiler .disable ()
415- elif fdo_profile is not None and len (fdo_profile ) > 0 :
416- # Store module under PGLE profiled module cache key.
417- cache_key = pgle_profiled_module_key
418- if is_multi_process and distributed .global_state .client is not None :
419- compile_options .executable_build_options .fdo_profile = _share_fdo_profiles (
420- computation , devices , compile_options , backend ,
421- distributed .global_state .client ,
422- min_device_process_id
423- )
424- else :
425- compile_options .executable_build_options .fdo_profile = fdo_profile
426- logger .debug (
427- "Compiling module %s with FDO profile: %s" ,
428- module_name ,
429- compile_options .executable_build_options .fdo_profile ,
430- )
431395
432396 cache_retrieval_start = time .monotonic ()
433397 retrieved_executable , retrieved_compile_time = _cache_read (
@@ -493,6 +457,75 @@ def compile_or_get_cached(
493457 cache_key ,
494458 )
495459
460+
461+ # When PGLE is enabled there might be 3 types of situations:
462+ # 1. PGLE profiled module (the one which was recompiled with FDO profile) is
463+ # in the persistent cache. In this case the module should be returned from
464+ # cache and PGLE should be disabled for this module. Is module is stored in
465+ # the persistent cache under the "pgle_profiled_module_key" which calculated
466+ # with replacing FDO profile with flag which identify that module were PGLE
467+ # profiled.
468+ # 2. PGLE profiled module is not in the persistent cache and the module is
469+ # getting built with an FDO profile. In this case we need to share FDO profile
470+ # with other processes and store the result under the
471+ # "pgle_profiled_module_key" so later in case 1 we will be able to find the
472+ # module.
473+ # 3. PGLE profiled module is not in the persistent cache and the module is
474+ # getting compiled to be PGLEd (FDO profile is empty). In this case we need to
475+ # simply return the non-PGLE profiled module from the persistent cache.
476+ def _resolve_pgle_module_cache_key (
477+ computation : ir .Module ,
478+ devices : np .ndarray ,
479+ compile_options : xc .CompileOptions ,
480+ backend : xc .Client ,
481+ pgle_profiler : profiler .PGLEProfiler | None ,
482+ is_multi_process : bool ,
483+ cache_key : str ,
484+ module_name : str ,
485+ min_device_process_id : int ,
486+ ) -> str :
487+ fdo_profile = compile_options .executable_build_options .fdo_profile
488+ compile_options .executable_build_options .fdo_profile = b"pgle profiled"
489+
490+ pgle_profiled_module_key = compilation_cache .get_cache_key (
491+ computation ,
492+ devices ,
493+ compile_options ,
494+ backend ,
495+ cache_key_type .IgnoreCallbacks .ALL ,
496+ )
497+ compile_options .executable_build_options .fdo_profile = fdo_profile
498+
499+ result_key = cache_key
500+ if _is_executable_in_cache (backend , pgle_profiled_module_key ):
501+ # Load PGLE profiled module from the persistent cache.
502+ result_key = pgle_profiled_module_key
503+ if pgle_profiler is not None :
504+ pgle_profiler .disable ()
505+ elif fdo_profile is not None and len (fdo_profile ) > 0 :
506+ # Store module under PGLE profiled module cache key.
507+ result_key = pgle_profiled_module_key
508+ if is_multi_process and distributed .global_state .client is not None :
509+ compile_options .executable_build_options .fdo_profile = (
510+ _share_fdo_profiles (
511+ computation ,
512+ devices ,
513+ compile_options ,
514+ backend ,
515+ distributed .global_state .client ,
516+ min_device_process_id ,
517+ )
518+ )
519+ else :
520+ compile_options .executable_build_options .fdo_profile = fdo_profile
521+ logger .debug (
522+ "Compiling module %s with FDO profile of length %d" ,
523+ module_name ,
524+ len (compile_options .executable_build_options .fdo_profile ),
525+ )
526+ return result_key
527+
528+
496529# The process that has the lowest device ID should share FDO profile before
497530# compilation with other processes.
498531def _share_fdo_profiles (
@@ -510,32 +543,39 @@ def _share_fdo_profiles(
510543 return fdo_profile
511544
512545 compile_options .executable_build_options .fdo_profile = b""
513- profile_key = (
514- compilation_cache .get_cache_key (
515- computation ,
516- devices ,
517- compile_options ,
518- backend ,
519- cache_key_type .IgnoreCallbacks .ALL ,
520- )
521- + "_fdo_sync"
522- )
546+ try :
547+ profile_key = (
548+ compilation_cache .get_cache_key (
549+ computation ,
550+ devices ,
551+ compile_options ,
552+ backend ,
553+ cache_key_type .IgnoreCallbacks .ALL ,
554+ )
555+ + "_fdo_sync"
556+ )
557+ except xc ._xla .XlaRuntimeError as ex :
558+ logger .error (
559+ "compile_or_get_cached: unable to generate cache key, "
560+ "skipping the fdo profile sharing: %s" ,
561+ ex ,
562+ )
563+ return fdo_profile
564+
523565 if profile_key in _share_fdo_profiles .modules_profiles :
524566 return _share_fdo_profiles .modules_profiles [profile_key ]
525567
526568 share_timeout = config .share_binary_between_hosts_timeout_ms .value
527569 if distributed .global_state .process_id == min_process_id :
528570 logger .debug (
529- "Sharing FDO profile: %s. For module %s. Process %d." ,
530- fdo_profile ,
571+ "Module %s. Sharing FDO profile. Process %d." ,
531572 module_name ,
532573 min_process_id ,
533574 )
534575 global_client .key_value_set_bytes (profile_key , fdo_profile )
535576 else :
536577 logger .debug (
537- "Waiting for FDO profile: %s. For module %s. Should be set by process %d." ,
538- fdo_profile ,
578+ "Module %s. Waiting for FDO profile which should be set by process %d." ,
539579 module_name ,
540580 min_process_id ,
541581 )
0 commit comments