2222from tensorrt_llm .logger import logger
2323from tensorrt_llm .mapping import Mapping
2424
25+ # Unique tag to avoid collisions with other comms
26+ PP_COMM_TAG_AUTOTUNING = 30000
27+
2528
2629class DistributedTuningStrategy (enum .Enum ):
2730 """
@@ -358,7 +361,7 @@ class AutoTunerProfilingCache:
358361 """
359362
360363 def __init__ (self ):
361- self .cache = {}
364+ self .cache : Dict [ Tuple , Tuple ] = dict ()
362365
363366 # Cache metadata for local storage and validation
364367 self .lib_version = tensorrt_llm .__version__
@@ -430,7 +433,7 @@ def get_cache_key(
430433 ),
431434 )
432435
433- def merge_cache_data (self , cache_data : Dict [str , Any ]):
436+ def merge_cache_data (self , cache_data : Dict [Tuple , Tuple ]):
434437 self .cache .update (cache_data )
435438
436439 def get_specific_custom_op (self , custom_op : str ) -> Dict [Tuple , Tuple ]:
@@ -615,7 +618,10 @@ def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000):
615618 self ._last_capture : Optional ['AutoTuner.TacticsCapture' ] = None
616619
617620 # Distributed tuning state
621+ self ._map_op_to_distributed_strategy : Dict [
622+ str , DistributedTuningStrategy ] = {}
618623 self ._dist : Optional [Distributed ] = None
624+ self ._has_received_cache : bool = False
619625 self .mapping : Mapping = Mapping ()
620626
621627 @classmethod
@@ -624,9 +630,6 @@ def get(cls):
624630 cls ._instance = AutoTuner ()
625631 return cls ._instance
626632
627- def set_mapping (self , mapping : Mapping = None ):
628- self .mapping = mapping
629-
630633 class TacticsCapture :
631634 """Object returned by capture() that can be iterated to get all tactic combinations.
632635
@@ -797,10 +800,18 @@ def choose_one(
797800 if self .is_tuning_mode and is_cache_hit :
798801 return (runners [best_runner_id ], best_tactic )
799802
803+ # If this PP rank has no cache hit, try to receive the cache from the previous rank.
804+ # Note that pp_recv is only called under tuning mode.
805+ self .cache_pp_recv ()
806+
800807 assert len (runners ) > 0 , "At least one runner is required"
801808 assert all ([isinstance (r , TunableRunner ) for r in runners ]), \
802809 "All Given runners must be subclass of TunableRunner"
803810
811+ # Record the distributed tuning strategy for the custom_op
812+ self ._map_op_to_distributed_strategy [
813+ custom_op ] = tuning_config .distributed_tuning_strategy
814+
804815 tuning_start_time = time .perf_counter ()
805816 profiles = self ._optimization_profiles (tuning_config , inputs )
806817
@@ -1507,3 +1518,32 @@ def _should_current_rank_tune(self,
15071518 f"[AutoTuner] Unknown distributed tuning strategy: { strategy } , falling back to independent"
15081519 )
15091520 return True
1521+
1522+ def cache_pp_recv (self ):
1523+ if self .mapping .has_pp () and not self .mapping .is_first_pp_rank (
1524+ ) and not self ._has_received_cache :
1525+ self ._debug_logger (
1526+ f"[AutoTuner] Receiving cache data from previous pp rank { self .mapping .prev_pp_rank ()} "
1527+ )
1528+ profiling_cache = self ._dist .recv_object (
1529+ src = self .mapping .prev_pp_rank (),
1530+ tag = PP_COMM_TAG_AUTOTUNING ,
1531+ )
1532+ # Guarantee that only receive cache once during a single warm-up run
1533+ # Notice that this flag should be reset after each warm-up run because isend is always called
1534+ self ._has_received_cache = True
1535+ self .profiling_cache .merge_cache_data (profiling_cache )
1536+
1537+ def cache_pp_send (self ):
1538+ if self .mapping .has_pp () and not self .mapping .is_last_pp_rank ():
1539+ self ._debug_logger (
1540+ f"[AutoTuner] Sending cache data to next pp rank { self .mapping .next_pp_rank ()} "
1541+ )
1542+ self ._dist .isend_object (
1543+ self .profiling_cache .cache ,
1544+ dest = self .mapping .next_pp_rank (),
1545+ tag = PP_COMM_TAG_AUTOTUNING ,
1546+ ).wait ()
1547+
1548+ def clean_pp_flag (self ):
1549+ self ._has_received_cache = False
0 commit comments