|
22 | 22 | from tensorrt_llm.logger import logger |
23 | 23 | from tensorrt_llm.mapping import Mapping |
24 | 24 |
|
| 25 | +# Unique tag to avoid collisions with other comms |
| 26 | +PP_COMM_TAG_AUTOTUNING = 22000 |
| 27 | + |
25 | 28 |
|
26 | 29 | class DistributedTuningStrategy(enum.Enum): |
27 | 30 | """ |
@@ -358,7 +361,7 @@ class AutoTunerProfilingCache: |
358 | 361 | """ |
359 | 362 |
|
360 | 363 | def __init__(self): |
361 | | - self.cache = {} |
| 364 | + self.cache: Dict[Tuple, Tuple] = dict() |
362 | 365 |
|
363 | 366 | # Cache metadata for local storage and validation |
364 | 367 | self.lib_version = tensorrt_llm.__version__ |
@@ -430,7 +433,7 @@ def get_cache_key( |
430 | 433 | ), |
431 | 434 | ) |
432 | 435 |
|
433 | | - def merge_cache_data(self, cache_data: Dict[str, Any]): |
| 436 | + def merge_cache_data(self, cache_data: Dict[Tuple, Tuple]): |
434 | 437 | self.cache.update(cache_data) |
435 | 438 |
|
436 | 439 | def get_specific_custom_op(self, custom_op: str) -> Dict[Tuple, Tuple]: |
@@ -615,7 +618,10 @@ def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000): |
615 | 618 | self._last_capture: Optional['AutoTuner.TacticsCapture'] = None |
616 | 619 |
|
617 | 620 |         # Distributed tuning state
| 621 | + self._map_op_to_distributed_strategy: Dict[ |
| 622 | + str, DistributedTuningStrategy] = {} |
618 | 623 | self._dist: Optional[Distributed] = None |
| 624 | + self._has_received_cache: bool = False |
619 | 625 | self.mapping: Mapping = Mapping() |
620 | 626 |
|
621 | 627 | @classmethod |
@@ -797,10 +803,17 @@ def choose_one( |
797 | 803 | if self.is_tuning_mode and is_cache_hit: |
798 | 804 | return (runners[best_runner_id], best_tactic) |
799 | 805 |
|
|  | 806 | +        # Reaching here means this PP rank did not get a cache hit, so it must wait to receive the cache
| 807 | + self.cache_pp_recv() |
| 808 | + |
800 | 809 | assert len(runners) > 0, "At least one runner is required" |
801 | 810 | assert all([isinstance(r, TunableRunner) for r in runners]), \ |
802 | 811 | "All Given runners must be subclass of TunableRunner" |
803 | 812 |
|
| 813 | + # Record the distributed tuning strategy for the custom_op |
| 814 | + self._map_op_to_distributed_strategy[ |
| 815 | + custom_op] = tuning_config.distributed_tuning_strategy |
| 816 | + |
804 | 817 | tuning_start_time = time.perf_counter() |
805 | 818 | profiles = self._optimization_profiles(tuning_config, inputs) |
806 | 819 |
|
@@ -1507,3 +1520,30 @@ def _should_current_rank_tune(self, |
1507 | 1520 | f"[AutoTuner] Unknown distributed tuning strategy: {strategy}, falling back to independent" |
1508 | 1521 | ) |
1509 | 1522 | return True |
| 1523 | + |
| 1524 | + def cache_pp_recv(self): |
| 1525 | + if self.mapping.has_pp() and not self.mapping.is_first_pp_rank( |
| 1526 | + ) and self._dist is not None and not self._has_received_cache: |
| 1527 | + self._debug_logger( |
| 1528 | + f"[AutoTuner] Receiving cache data from previous pp rank {self.mapping.prev_pp_rank()}" |
| 1529 | + ) |
| 1530 | + profiling_cache = self._dist.recv_object( |
| 1531 | + self.mapping.prev_pp_rank(), |
| 1532 | + tag=PP_COMM_TAG_AUTOTUNING, |
| 1533 | + ) |
| 1534 | + # Every time a PP rank only receives a cache data once. After that, it should not do any further receiving |
| 1535 | + self._has_received_cache = True |
| 1536 | + self.profiling_cache.merge_cache_data(profiling_cache) |
| 1537 | + |
| 1538 | + def cache_pp_send(self): |
| 1539 | + # Send all cache contents to the next PP rank |
| 1540 | + if self.mapping.has_pp( |
| 1541 | + ) and not self.mapping.is_last_pp_rank() and self._dist is not None: |
| 1542 | + self._debug_logger( |
| 1543 | + f"[AutoTuner] Sending cache data to next pp rank {self.mapping.next_pp_rank()}" |
| 1544 | + ) |
| 1545 | + self._dist.send_object( |
| 1546 | + self.profiling_cache.cache, |
| 1547 | + self.mapping.next_pp_rank(), |
| 1548 | + tag=PP_COMM_TAG_AUTOTUNING, |
| 1549 | + ) |
0 commit comments