@@ -1,6 +1,7 @@
 import ast
 import contextlib
 import copy
+import enum
 import inspect
 import itertools
 import json
@@ -16,8 +17,25 @@
 from cuda.bindings import driver

 import tensorrt_llm
+from tensorrt_llm._torch.distributed import Distributed
 from tensorrt_llm.bindings.internal.runtime import delay_kernel
 from tensorrt_llm.logger import logger
+from tensorrt_llm.mapping import Mapping
+
+
+class DistributedTuningStrategy(enum.Enum):
+    """Strategy for distributed tuning.
+
+    Values:
+        BROADCAST: One rank (rank 0) tunes and broadcasts its results to the others.
+        INDEPENDENT: Each rank tunes independently (the default for non-communication ops).
+        MERGE: All ranks tune the full tactic set and merge the results.
+        PARALLEL: Each rank tunes a partial set of tactics; the results are then merged.
+    """
+    BROADCAST = "broadcast"
+    INDEPENDENT = "independent"
+    MERGE = "merge"
+    PARALLEL = "parallel"


 @dataclass(slots=True, unsafe_hash=True)
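Because the enum members carry string values, a strategy can also be looked up from a plain configuration string; a minimal sketch (the config source is hypothetical):

    strategy = DistributedTuningStrategy("parallel")
    assert strategy is DistributedTuningStrategy.PARALLEL
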
@@ -99,13 +117,15 @@ class TuningConfig:
             This flag is to create circular buffer of input tensors to avoid L2 cache hits to simulate cold L2 cache.
             Notice that not all tuning processes can benefit from this feature.
         use_cuda_graph (bool): Whether to use CUDA graph for the tuning process.
+        distributed_tuning_strategy (DistributedTuningStrategy): Strategy for distributed tuning.
     """
     dynamic_tensor_specs: Tuple[DynamicTensorSpec, ...] = ()
     constraint_specs: Tuple[ConstraintSpec, ...] = ()
     tune_max_num_tokens: int = None
     inputs_pre_hook: Callable = None
     use_cold_l2_cache: bool = False
     use_cuda_graph: bool = True
+    distributed_tuning_strategy: DistributedTuningStrategy = DistributedTuningStrategy.INDEPENDENT


 @dataclass(unsafe_hash=True)
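A sketch of how an op could opt into one of the new strategies through its TuningConfig; the field and its default come from this diff, the other values are illustrative:

    tuning_config = TuningConfig(
        tune_max_num_tokens=8192,
        distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL,
    )
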
@@ -229,7 +249,16 @@ def unique_id(self):


 @contextlib.contextmanager
-def autotune(tune_mode: bool = True, cache_path: str = None, rank: int = 0):
+def autotune(tune_mode: bool = True, cache_path: str = None):
+    """Context manager for autotuning with distributed support.
+
+    Args:
+        tune_mode: Whether to enable tuning mode.
+        cache_path: Path to save/load cache files.
+    """
+    autotuner = AutoTuner.get()
+    rank = autotuner.mapping.rank
+
     # if cache_path is provided, use the rank-specific file
     tune_required = tune_mode
     if cache_path is not None:
@@ -242,25 +271,27 @@ def autotune(tune_mode: bool = True, cache_path: str = None, rank: int = 0):
         if file_exists:
             logger.info(
                 f"[Autotuner] Loading cache from {cache_path_no_ext_rank}")
-            AutoTuner.get().profiling_cache.load_cache(cache_path_no_ext_rank)
+            autotuner.profiling_cache.load_cache(cache_path_no_ext_rank)

     # record the old tuning mode
-    old_mode = AutoTuner.get().is_tuning_mode
-    AutoTuner.get().is_tuning_mode = tune_required
+    old_mode = autotuner.is_tuning_mode
+    autotuner.is_tuning_mode = tune_required
     autotune_enabled = tune_required and not old_mode
+
     if autotune_enabled:
         logger.info("[Autotuner] Autotuning process starts ...")
+
     try:
         yield
     finally:
-        AutoTuner.get().is_tuning_mode = old_mode
+        autotuner.is_tuning_mode = old_mode
         if autotune_enabled:
             logger.info("[Autotuner] Autotuning process ends")

             # save cache
             if cache_path is not None:
                 logger.info(f"[Autotuner] Saving cache to {cache_path_no_ext_rank}")
-                AutoTuner.get().profiling_cache.save_cache(cache_path_no_ext_rank)
+                autotuner.profiling_cache.save_cache(cache_path_no_ext_rank)


 @dataclass
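With the explicit rank argument removed, the rank suffix for the cache file now comes from AutoTuner.mapping; a minimal usage sketch (the warmup call is hypothetical):

    with autotune(cache_path="/tmp/autotuner_cache.json"):
        run_warmup_forward()  # hypothetical call that exercises the tunable custom ops
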
@@ -399,6 +430,9 @@ def get_cache_key(
             ),
         )

+    def merge_cache_data(self, cache_data: Dict[str, Any]):
+        self.cache.update(cache_data)
+
     def get_specific_custom_op(self, custom_op: str) -> Dict[Tuple, Tuple]:
         return {k: v for k, v in self.cache.items() if k[0] == custom_op}

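merge_cache_data() is a plain dict.update(), so entries gathered from other ranks overwrite local entries with the same key; the key and value below are illustrative, not the real cache-key layout:

    cache = AutoTuner.get().profiling_cache
    cache.merge_cache_data({("op_a", "shape_key"): (0, 3, 0.12)})  # (runner_id, tactic, time)
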
@@ -561,6 +595,11 @@ class AutoTuner:
     _instance = None

     def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000):
+        # Increase log level for AutoTuner associated logger
+        self._log_level_to_info = os.getenv(
+            "TLLM_AUTOTUNER_LOG_LEVEL_DEBUG_TO_INFO", '0') == '1'
+        self._debug_logger = logger.info if self._log_level_to_info else logger.debug
+
         self.repeat = repeat
         self.warmup = warmup
         self.stream_delay_micro_secs = stream_delay_micro_secs
@@ -575,17 +614,19 @@ def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000):
         # Last captured choose_one() contexts
         self._last_capture: Optional['AutoTuner.TacticsCapture'] = None

-        # Increase log level for AutoTuner associated logger
-        self._log_level_to_info = os.getenv(
-            "TLLM_AUTOTUNER_LOG_LEVEL_DEBUG_TO_INFO", '0') == '1'
-        self._debug_logger = logger.info if self._log_level_to_info else logger.debug
+        # Distributed tuning state
+        self._dist: Optional[Distributed] = None
+        self.mapping: Mapping = Mapping()

     @classmethod
     def get(cls):
         if cls._instance is None:
             cls._instance = AutoTuner()
         return cls._instance

+    def set_mapping(self, mapping: Mapping = None):
+        self.mapping = mapping
+
     class TacticsCapture:
         """Object returned by capture() that can be iterated to get all tactic combinations.

@@ -768,42 +809,26 @@ def choose_one(
             self.stats.tuned_op_profiled_configs[custom_op] = 0
         if custom_op not in self.stats.failed_profiling_count:
             self.stats.failed_profiling_count[custom_op] = set()
-        new_tuning_failure_occured = False
-
-        for p in profiles:
-            tensors = self._prepare_input_tensors(p, inputs)
-            is_cache_hit, *_ = self.profiling_cache.search_cache(
-                custom_op, runners, p.get_opt_shapes(), tuning_config)
-            if not is_cache_hit:
-                # Initialize runner and tactic as None in case of no valid tactic or runners are found
-                best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners(
-                    custom_op, runners, tensors, p, tuning_config, **kwargs)
-                if best_runner_id is not None:
-                    # At least one valid (runner, tactic) pair is found
-                    cache_key = self.profiling_cache.get_cache_key(
-                        custom_op, runners[best_runner_id], p.get_opt_shapes(),
-                        tuning_config)
-
-                    self._debug_logger(
-                        f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
-                    )
-                    # inspect call stack
-                    self.profiling_cache[cache_key] = (best_runner_id,
-                                                       best_tactic, min_time)
-
-                    self.stats.tuned_op_profiled_configs[custom_op] += 1
-                else:
-                    logger.warning_once(
-                        f"[Autotuner] No valid runner/tactic was found for custom_op={custom_op}, input_shapes={input_shapes}. "
-                        f"At least one valid (runner, tactic) pair is required. "
-                        f"If get_valid_tactics is intended to return empty list, please ensure that this profile is not valid for the custom_op "
-                        f"and should not occurs during the inference stage, or fallback tactic is implemented. Otherwise, the the tuning process will crash.",
-                        key=(custom_op, "warning_autotuning_no_valid_tactic"),
-                    )
-                new_tuning_failure_occured = new_tuning_failure_occured or has_tuning_failure_occured
+        new_tuning_failure_occurred = False
+
+        # Profile only on ranks selected by the distributed tuning strategy
+        if self._should_current_rank_tune(
+                tuning_config.distributed_tuning_strategy):
+            for p in profiles:
+                tensors = self._prepare_input_tensors(p, inputs)
+                is_cache_hit, *_ = self.profiling_cache.search_cache(
+                    custom_op, runners, p.get_opt_shapes(), tuning_config)
+                if not is_cache_hit:
+                    # Runner and tactic stay None if no valid (runner, tactic) pair is found
+                    best_runner_id, best_tactic, min_time, has_tuning_failure_occurred = self._profile_runners(
+                        custom_op, runners, tensors, p, tuning_config, **kwargs)
+                    new_tuning_failure_occurred = new_tuning_failure_occurred or has_tuning_failure_occurred
+
+        self._maybe_sync_cache_data(tuning_config.distributed_tuning_strategy,
+                                    custom_op)

         # If failed profiling tactics occurs, log the error.
-        if new_tuning_failure_occured:
+        if new_tuning_failure_occurred:
             logger.warning_once(
                 f"[Autotuner] New tuning error occurs:"
                 f"Total failed profiling tactics occurs: {len(self.stats.failed_profiling_count[custom_op])} for custom_op={custom_op}. "
@@ -834,7 +859,7 @@ def _profile_runners(
         **kwargs,
     ) -> float:
         min_time = float('inf')
-        has_tuning_failure_occured = False
+        has_tuning_failure_occurred = False
         best_runner_id, best_tactic = None, None
         # If the inputs_pre_hook is provided, it will be called before profiling.
         if tuning_config.inputs_pre_hook is not None:
@@ -845,8 +870,11 @@ def _profile_runners(
                 p.name
                 for p in inspect.signature(runner.forward).parameters.values()
             }
-            valid_tactics = runner.get_valid_tactics(input_tensors, profile,
-                                                     **kwargs)
+            all_valid_tactics = runner.get_valid_tactics(
+                input_tensors, profile, **kwargs)
+
+            valid_tactics = self._maybe_parallelize_tactics(
+                all_valid_tactics, tuning_config.distributed_tuning_strategy)
             if "do_preparation" in runner_arg_names and len(valid_tactics) > 0:
                 runner(
                     input_tensors,
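The PARALLEL strategy splits the tactic list round-robin over TP ranks (see _maybe_parallelize_tactics near the end of this diff); a self-contained illustration with tp_size=2:

    tactics = ["t0", "t1", "t2", "t3", "t4"]
    tp_size, tp_rank = 2, 1
    mine = [t for i, t in enumerate(tactics) if i % tp_size == tp_rank]
    assert mine == ["t1", "t3"]  # rank 0 would get ["t0", "t2", "t4"]
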
@@ -882,12 +910,36 @@ def _profile_runners(
                    # Set time_measured to inf to notify the failure of the tactic. This can happen when `get_valid_tactics` mistakenly return wrong tactics
                    # or some runtime error occurs during profiling.
                    time_measured = float('inf')
-                    has_tuning_failure_occured = True
+                    has_tuning_failure_occurred = True
                 if time_measured < min_time:
                     min_time = time_measured
                     best_runner_id, best_tactic = runner_id, tac

-        return best_runner_id, best_tactic, min_time, has_tuning_failure_occured
+        if best_runner_id is not None:
+            # At least one valid (runner, tactic) pair is found
+            cache_key = self.profiling_cache.get_cache_key(
+                custom_op, runners[best_runner_id], profile.get_opt_shapes(),
+                tuning_config)
+
+            self._debug_logger(
+                f"[Autotuner] Profiling runner={runners[best_runner_id]}, tactic={best_tactic} for cache_key={cache_key}."
+            )
+            # inspect call stack
+            # TODO: use a named tuple to make the cache value more readable
+            self.profiling_cache[cache_key] = (best_runner_id, best_tactic,
+                                               min_time)
+
+            self.stats.tuned_op_profiled_configs[custom_op] += 1
+        else:
+            logger.warning_once(
+                f"[Autotuner] No valid runner/tactic was found for custom_op={custom_op}, input_shapes={profile.get_opt_shapes()}. "
+                f"At least one valid (runner, tactic) pair is required. "
+                f"If get_valid_tactics is intended to return an empty list, please ensure that this profile is not valid for the custom_op "
+                f"and does not occur during the inference stage, or that a fallback tactic is implemented. Otherwise, the tuning process will crash.",
+                key=(custom_op, "warning_autotuning_no_valid_tactic"),
+            )
+
+        return best_runner_id, best_tactic, min_time, has_tuning_failure_occurred

     def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]:

@@ -1358,3 +1410,103 @@ def _cudaGetErrorEnum(self, error) -> str:
             return nvrtc.nvrtcGetErrorString(error)[1]
         else:
             raise RuntimeError("Unknown error type: {}".format(error))
+
+    def setup_distributed_state(self, mapping: Mapping, dist: Distributed):
+        """Set up the distributed communication state for autotuning."""
+        self.mapping = mapping
+        self._dist = dist
+        self._debug_logger(
+            f"[AutoTuner] Distributed tuning enabled: {self._is_distributed()}"
+        )
+
+    def _is_distributed(self) -> bool:
+        """Check whether we are running in a distributed environment."""
+        return self.mapping is not None and self.mapping.tp_size > 1 and self._dist is not None
+
+    def _maybe_parallelize_tactics(
+            self, all_valid_tactics: List[Any],
+            strategy: DistributedTuningStrategy) -> List[Any]:
+        """Parallelize tactics across all TP ranks if the strategy is PARALLEL."""
+        if strategy == DistributedTuningStrategy.PARALLEL:
+            # Only distribute across TP ranks: each TP rank tunes just the
+            # tactics assigned to it.
+            tp_size = self.mapping.tp_size
+            tp_rank = self.mapping.tp_rank
+            valid_tactics = []
+            for idx, tactic in enumerate(all_valid_tactics):
+                if idx % tp_size == tp_rank:
+                    valid_tactics.append(tactic)
+            return valid_tactics
+        else:
+            return all_valid_tactics
+
+    def _maybe_sync_cache_data(self, strategy: DistributedTuningStrategy,
+                               custom_op: str):
+        """Synchronize cache data across all ranks."""
+        if not self._is_distributed():
+            logger.warning(
+                "[AutoTuner] Not in a distributed environment, skipping synchronization"
+            )
+            return
+
+        if strategy == DistributedTuningStrategy.BROADCAST:
+            self._broadcast_cache_data(custom_op)
+        elif strategy == DistributedTuningStrategy.INDEPENDENT:
+            return
+        elif strategy == DistributedTuningStrategy.MERGE:
+            self._merge_cache_data(custom_op)
+        elif strategy == DistributedTuningStrategy.PARALLEL:
+            self._merge_cache_data(custom_op)
+        else:
+            logger.error(
+                f"[AutoTuner] Unknown distributed tuning strategy: {strategy}, falling back to independent"
+            )
+            return
+
+    def _merge_cache_data(self, custom_op: str):
+        cache_data = self.profiling_cache.get_specific_custom_op(custom_op)
+        merged_cache_data = dict()
+        all_cache_data = self._dist.tp_allgather(obj=cache_data)
+
+        for data in all_cache_data:
+            for key, value in data.items():
+                current_time = merged_cache_data.get(key, [float('inf')])[-1]
+                if value[-1] < current_time:
+                    merged_cache_data[key] = value
+
+        self.profiling_cache.merge_cache_data(merged_cache_data)
+
+    def _broadcast_cache_data(self, custom_op: str) -> None:
+        """Broadcast tactics from the root rank to all other ranks."""
+        cache_data = self.profiling_cache.get_specific_custom_op(custom_op)
+        root = 0
+        cache_data = self._dist.tp_broadcast(obj=cache_data, root=root)
+
+        self.profiling_cache.merge_cache_data(cache_data)
+
+    def _should_current_rank_tune(self,
+                                  strategy: DistributedTuningStrategy) -> bool:
+        """Determine whether this rank should perform tuning for the given strategy."""
+        if not self._is_distributed():
+            return True
+
+        if strategy == DistributedTuningStrategy.BROADCAST:
+            # Only rank 0 tunes
+            return self.mapping.rank == 0
+        elif strategy in {
+                DistributedTuningStrategy.INDEPENDENT,
+                DistributedTuningStrategy.MERGE,
+                DistributedTuningStrategy.PARALLEL,
+        }:
+            # All ranks participate in tuning
+            return True
+        else:
+            logger.error(
+                f"[AutoTuner] Unknown distributed tuning strategy: {strategy}, falling back to independent"
+            )
+            return True
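
A self-contained illustration of the merge rule in _merge_cache_data(): after tp_allgather, the entry with the smallest measured time (the last tuple element) wins for each key; the keys and timings are made up:

    all_cache_data = [
        {("op", "k"): (0, "tactic_a", 0.30)},  # rank 0
        {("op", "k"): (1, "tactic_b", 0.25)},  # rank 1
    ]
    merged = {}
    for data in all_cache_data:
        for key, value in data.items():
            if value[-1] < merged.get(key, (float("inf"),))[-1]:
                merged[key] = value
    assert merged[("op", "k")] == (1, "tactic_b", 0.25)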