Skip to content

Commit dcf589b

Browse files
committed
[TRTLLM-9615][feat] Support PP in the distributed tuning system
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
1 parent 33a90f2 commit dcf589b

File tree

3 files changed

+56
-3
lines changed

3 files changed

+56
-3
lines changed

tensorrt_llm/_torch/autotuner.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from tensorrt_llm.logger import logger
2323
from tensorrt_llm.mapping import Mapping
2424

25+
# Unique tag to avoid collisions with other comms
26+
PP_COMM_TAG_AUTOTUNING = 22000
27+
2528

2629
class DistributedTuningStrategy(enum.Enum):
2730
"""
@@ -358,7 +361,7 @@ class AutoTunerProfilingCache:
358361
"""
359362

360363
def __init__(self):
361-
self.cache = {}
364+
self.cache: Dict[Tuple, Tuple] = dict()
362365

363366
# Cache metadata for local storage and validation
364367
self.lib_version = tensorrt_llm.__version__
@@ -430,7 +433,7 @@ def get_cache_key(
430433
),
431434
)
432435

433-
def merge_cache_data(self, cache_data: Dict[str, Any]):
436+
def merge_cache_data(self, cache_data: Dict[Tuple, Tuple]):
434437
self.cache.update(cache_data)
435438

436439
def get_specific_custom_op(self, custom_op: str) -> Dict[Tuple, Tuple]:
@@ -615,7 +618,10 @@ def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000):
615618
self._last_capture: Optional['AutoTuner.TacticsCapture'] = None
616619

617620
# Distributed tuning state
621+
self._map_op_to_distributed_strategy: Dict[
622+
str, DistributedTuningStrategy] = {}
618623
self._dist: Optional[Distributed] = None
624+
self._has_received_cache: bool = False
619625
self.mapping: Mapping = Mapping()
620626

621627
@classmethod
@@ -797,10 +803,17 @@ def choose_one(
797803
if self.is_tuning_mode and is_cache_hit:
798804
return (runners[best_runner_id], best_tactic)
799805

806+
# Reaching here means this PP rank does not have a cache hit, so we need to wait for a recv
807+
self.cache_pp_recv()
808+
800809
assert len(runners) > 0, "At least one runner is required"
801810
assert all([isinstance(r, TunableRunner) for r in runners]), \
802811
"All Given runners must be subclass of TunableRunner"
803812

813+
# Record the distributed tuning strategy for the custom_op
814+
self._map_op_to_distributed_strategy[
815+
custom_op] = tuning_config.distributed_tuning_strategy
816+
804817
tuning_start_time = time.perf_counter()
805818
profiles = self._optimization_profiles(tuning_config, inputs)
806819

@@ -1507,3 +1520,30 @@ def _should_current_rank_tune(self,
15071520
f"[AutoTuner] Unknown distributed tuning strategy: {strategy}, falling back to independent"
15081521
)
15091522
return True
1523+
1524+
def cache_pp_recv(self):
1525+
if self.mapping.has_pp() and not self.mapping.is_first_pp_rank(
1526+
) and self._dist is not None and not self._has_received_cache:
1527+
self._debug_logger(
1528+
f"[AutoTuner] Receiving cache data from previous pp rank {self.mapping.prev_pp_rank()}"
1529+
)
1530+
profiling_cache = self._dist.recv_object(
1531+
self.mapping.prev_pp_rank(),
1532+
tag=PP_COMM_TAG_AUTOTUNING,
1533+
)
1534+
# Each PP rank receives cache data only once. After that, it should not do any further receiving
1535+
self._has_received_cache = True
1536+
self.profiling_cache.merge_cache_data(profiling_cache)
1537+
1538+
def cache_pp_send(self):
1539+
# Send all cache contents to the next PP rank
1540+
if self.mapping.has_pp(
1541+
) and not self.mapping.is_last_pp_rank() and self._dist is not None:
1542+
self._debug_logger(
1543+
f"[AutoTuner] Sending cache data to next pp rank {self.mapping.next_pp_rank()}"
1544+
)
1545+
self._dist.send_object(
1546+
self.profiling_cache.cache,
1547+
self.mapping.next_pp_rank(),
1548+
tag=PP_COMM_TAG_AUTOTUNING,
1549+
)

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,14 @@ def _(
693693

694694
class NVFP4GemmUnifiedRunner(TunableRunner):
695695
runner_dict = dict()
696+
tuning_config = TuningConfig(
697+
dynamic_tensor_specs=(DynamicTensorSpec(
698+
0, 0, get_last_power_of_2_num_tokens_buckets,
699+
last_positive_power_of_2), ),
700+
constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
701+
# nested tuning should always be independent
702+
distributed_tuning_strategy=DistributedTuningStrategy.INDEPENDENT,
703+
)
696704

697705
def __init__(self, to_userbuffers: bool, output_dtype: torch.dtype,
698706
allowed_backends: List[str]):
@@ -943,7 +951,7 @@ def nvfp4_gemm(
943951
_, best_tactic = tuner.choose_one(
944952
"trtllm::nvfp4_gemm::gemm",
945953
[runner],
946-
FP4GemmRunner.
954+
NVFP4GemmUnifiedRunner.
947955
tuning_config, # All runners use the same tuning_config
948956
[act_fp4, weight, act_sf, weight_scale, alpha],
949957
)

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,9 +649,14 @@ def _run_autotuner_warmup(self, resource_manager: ResourceManager):
649649
if self.is_draft_model and isinstance(
650650
spec_resource_manager, Eagle3ResourceManager):
651651
spec_resource_manager.is_first_draft = True
652+
652653
self.forward(batch,
653654
new_tensors_device=None,
654655
resource_manager=resource_manager)
656+
657+
# Sync the cache after the tuning process
658+
AutoTuner.get().cache_pp_send()
659+
655660
torch.cuda.synchronize()
656661

657662
logger.info(

0 commit comments

Comments
 (0)