Skip to content

Commit ebae79d

Browse files
committed
[TRTLLM-9615][feat] Support PP in the distributed tuning system
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
1 parent 5d71f66 commit ebae79d

File tree

3 files changed

+46
-4
lines changed

3 files changed

+46
-4
lines changed

tensorrt_llm/_torch/autotuner.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class TuningConfig:
125125
inputs_pre_hook: Callable = None
126126
use_cold_l2_cache: bool = False
127127
use_cuda_graph: bool = True
128-
distributed_tuning_strategy: DistributedTuningStrategy = DistributedTuningStrategy.INDEPENDENT
128+
distributed_tuning_strategy: DistributedTuningStrategy = DistributedTuningStrategy.BROADCAST
129129

130130

131131
@dataclass(unsafe_hash=True)
@@ -358,7 +358,7 @@ class AutoTunerProfilingCache:
358358
"""
359359

360360
def __init__(self):
361-
self.cache = {}
361+
self.cache: Dict[Tuple, Tuple] = dict()
362362

363363
# Cache metadata for local storage and validation
364364
self.lib_version = tensorrt_llm.__version__
@@ -430,7 +430,7 @@ def get_cache_key(
430430
),
431431
)
432432

433-
def merge_cache_data(self, cache_data: Dict[str, Any]):
433+
def merge_cache_data(self, cache_data: Dict[Tuple, Tuple]):
434434
self.cache.update(cache_data)
435435

436436
def get_specific_custom_op(self, custom_op: str) -> Dict[Tuple, Tuple]:
@@ -615,6 +615,8 @@ def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000):
615615
self._last_capture: Optional['AutoTuner.TacticsCapture'] = None
616616

617617
# Distributed tuning state
618+
self._map_op_to_distributed_strategy: Dict[
619+
str, DistributedTuningStrategy] = {}
618620
self._dist: Optional[Distributed] = None
619621
self.mapping: Mapping = Mapping()
620622

@@ -801,6 +803,10 @@ def choose_one(
801803
assert all([isinstance(r, TunableRunner) for r in runners]), \
802804
"All Given runners must be subclass of TunableRunner"
803805

806+
# Record the distributed tuning strategy for the custom_op
807+
self._map_op_to_distributed_strategy[
808+
custom_op] = tuning_config.distributed_tuning_strategy
809+
804810
tuning_start_time = time.perf_counter()
805811
profiles = self._optimization_profiles(tuning_config, inputs)
806812

@@ -1510,3 +1516,24 @@ def _should_current_rank_tune(self,
15101516
f"[AutoTuner] Unknown distributed tuning strategy: {strategy}, falling back to independent"
15111517
)
15121518
return True
1519+
1520+
def cache_sync_pp_recv(self):
1521+
if self.mapping.has_pp() and not self.mapping.is_first_pp_rank:
1522+
profiling_cache = self._dist.recv_object(
1523+
self.mapping.prev_pp_rank())
1524+
self.profiling_cache.merge_cache_data(profiling_cache)
1525+
1526+
def cache_sync_pp_send(self):
1527+
# Ops with the INDEPENDENT strategy shall not be sent
1528+
if self.mapping.has_pp() and not self.mapping.is_last_pp_rank:
1529+
dependent_custom_ops = [
1530+
op for op, strategy in
1531+
self._map_op_to_distributed_strategy.items()
1532+
if strategy != DistributedTuningStrategy.INDEPENDENT
1533+
]
1534+
dependent_custom_ops_cache = dict()
1535+
for op in dependent_custom_ops:
1536+
dependent_custom_ops_cache.update(
1537+
self.profiling_cache.get_specific_custom_op(op))
1538+
self._dist.send_object(dependent_custom_ops_cache,
1539+
self.mapping.next_pp_rank())

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,14 @@ def _(
693693

694694
class NVFP4GemmUnifiedRunner(TunableRunner):
695695
runner_dict = dict()
696+
tuning_config = TuningConfig(
697+
dynamic_tensor_specs=(DynamicTensorSpec(
698+
0, 0, get_last_power_of_2_num_tokens_buckets,
699+
last_positive_power_of_2), ),
700+
constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
701+
# nested tuning should always be independent
702+
distributed_tuning_strategy=DistributedTuningStrategy.INDEPENDENT,
703+
)
696704

697705
def __init__(self, to_userbuffers: bool, output_dtype: torch.dtype,
698706
allowed_backends: List[str]):
@@ -943,7 +951,7 @@ def nvfp4_gemm(
943951
_, best_tactic = tuner.choose_one(
944952
"trtllm::nvfp4_gemm::gemm",
945953
[runner],
946-
FP4GemmRunner.
954+
NVFP4GemmUnifiedRunner.
947955
tuning_config, # All runners use the same tuning_config
948956
[act_fp4, weight, act_sf, weight_scale, alpha],
949957
)

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,9 +649,16 @@ def _run_autotuner_warmup(self, resource_manager: ResourceManager):
649649
if self.is_draft_model and isinstance(
650650
spec_resource_manager, Eagle3ResourceManager):
651651
spec_resource_manager.is_first_draft = True
652+
# Sync the cache before the forward pass on all but the first pp rank
653+
AutoTuner.get().cache_sync_pp_recv()
654+
652655
self.forward(batch,
653656
new_tensors_device=None,
654657
resource_manager=resource_manager)
658+
659+
# Sync the cache after the forward pass on all but the last pp rank
660+
AutoTuner.get().cache_sync_pp_send()
661+
655662
torch.cuda.synchronize()
656663

657664
logger.info(

0 commit comments

Comments
 (0)