@@ -125,7 +125,7 @@ class TuningConfig:
125125 inputs_pre_hook : Callable = None
126126 use_cold_l2_cache : bool = False
127127 use_cuda_graph : bool = True
128- distributed_tuning_strategy : DistributedTuningStrategy = DistributedTuningStrategy .INDEPENDENT
128+ distributed_tuning_strategy : DistributedTuningStrategy = DistributedTuningStrategy .BROADCAST
129129
130130
131131@dataclass (unsafe_hash = True )
@@ -358,7 +358,7 @@ class AutoTunerProfilingCache:
358358 """
359359
360360 def __init__ (self ):
361- self .cache = {}
361+ self .cache : Dict [ Tuple , Tuple ] = dict ()
362362
363363 # Cache metadata for local storage and validation
364364 self .lib_version = tensorrt_llm .__version__
@@ -430,7 +430,7 @@ def get_cache_key(
430430 ),
431431 )
432432
def merge_cache_data(self, cache_data: Dict[Tuple, Tuple]):
    """Fold externally-provided profiling results into the local cache.

    Entries in ``cache_data`` overwrite any existing entries that share
    the same key, matching ``dict.update`` semantics.
    """
    for key, value in cache_data.items():
        self.cache[key] = value
435435
436436 def get_specific_custom_op (self , custom_op : str ) -> Dict [Tuple , Tuple ]:
@@ -615,6 +615,8 @@ def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000):
615615 self ._last_capture : Optional ['AutoTuner.TacticsCapture' ] = None
616616
617617 # Distributed tuning state
618+ self ._map_op_to_distributed_strategy : Dict [
619+ str , DistributedTuningStrategy ] = {}
618620 self ._dist : Optional [Distributed ] = None
619621 self .mapping : Mapping = Mapping ()
620622
@@ -801,6 +803,10 @@ def choose_one(
801803 assert all ([isinstance (r , TunableRunner ) for r in runners ]), \
802804 "All Given runners must be subclass of TunableRunner"
803805
806+ # Record the distributed tuning strategy for the custom_op
807+ self ._map_op_to_distributed_strategy [
808+ custom_op ] = tuning_config .distributed_tuning_strategy
809+
804810 tuning_start_time = time .perf_counter ()
805811 profiles = self ._optimization_profiles (tuning_config , inputs )
806812
@@ -1510,3 +1516,24 @@ def _should_current_rank_tune(self,
15101516 f"[AutoTuner] Unknown distributed tuning strategy: { strategy } , falling back to independent"
15111517 )
15121518 return True
1519+
def cache_sync_pp_recv(self):
    """Receive profiling-cache entries from the previous pipeline-parallel rank.

    No-op when pipeline parallelism is disabled or this is the first PP
    rank (there is no upstream rank to receive from). Received entries
    are merged into the local profiling cache.
    """
    if not self.mapping.has_pp():
        return
    if self.mapping.is_first_pp_rank:
        return
    received = self._dist.recv_object(self.mapping.prev_pp_rank())
    self.profiling_cache.merge_cache_data(received)
1525+
def cache_sync_pp_send(self):
    """Send dependent ops' profiling-cache entries to the next pipeline-parallel rank.

    Ops tuned with the INDEPENDENT strategy shall not be sent; every other
    op's cache entries are bundled into one payload and forwarded downstream.
    No-op when pipeline parallelism is disabled or this is the last PP rank.
    """
    if not (self.mapping.has_pp() and not self.mapping.is_last_pp_rank):
        return
    payload = dict()
    for op, strategy in self._map_op_to_distributed_strategy.items():
        if strategy == DistributedTuningStrategy.INDEPENDENT:
            continue
        payload.update(self.profiling_cache.get_specific_custom_op(op))
    self._dist.send_object(payload, self.mapping.next_pp_rank())
0 commit comments