[None][feat] Support using FlexKV as another KV Cache Offloading option. #9698
@@ -1667,6 +1667,16 @@ class GenericLlmRequest
             [](auto reason) { return reason == executor::FinishReason::kLENGTH; });
     }

+    [[nodiscard]] bool isFinishedNormal() const noexcept
Collaborator
What do you mean by "normal" here? Could you modify the function name to make it more expressive?
+    {
+        return std::all_of(mFinishReasons.begin(), mFinishReasons.end(),
+            [](auto reason)
+            {
+                return reason == executor::FinishReason::kEND_ID || reason == executor::FinishReason::kSTOP_WORDS
+                    || reason == executor::FinishReason::kLENGTH;
+            });
+    }

     [[nodiscard]] bool isTimedOut() const
     {
         if (!mAllottedTimeMs.has_value())
@@ -234,6 +234,13 @@ def update_state_after_alloc(self, request: LlmRequest,
             block_ids: The KV cache block IDs that were allocated.
         """

+    def wait_for_initialization(self):
Collaborator
As we have …
| """ | ||
| Some connectors need to wait for some resources to be initialized. | ||
| For example, FlexKV needs to wait for the FlexKV manager to be initialized. | ||
| """ | ||
| return | ||
|
|
||
|
|
||
| # An internal dataclass to handle async saving/loading requests. | ||
| @dataclass | ||
|
|
@@ -570,3 +577,7 @@ def layer_pre_hook(self, module, *args): | |
|
|
||
| def layer_post_hook(self, module, *args): | ||
| self.worker.save_kv_layer(module.layer_idx, torch.cuda.current_stream()) | ||
|
|
||
| def wait_for_initialization(self): | ||
| if self.scheduler is not None: | ||
| self.scheduler.wait_for_initialization() | ||
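For illustration, a connector scheduler could implement the new hook by blocking until its backing KV manager reports ready. This is a minimal, hypothetical sketch under assumed names (the `ExampleConnectorScheduler` class and `_manager_ready` event are not part of this PR); the default added here simply returns immediately:

```python
import threading


class ExampleConnectorScheduler:
    """Hypothetical connector scheduler showing one way to use wait_for_initialization()."""

    def __init__(self):
        # Set by whichever thread brings up the external KV cache manager (e.g. FlexKV).
        self._manager_ready = threading.Event()

    def mark_manager_ready(self):
        # Called once the external manager has finished initializing.
        self._manager_ready.set()

    def wait_for_initialization(self):
        # Block the executor until the manager is initialized; the base-class
        # default introduced in this PR is a no-op that returns immediately.
        self._manager_ready.wait()
```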
@@ -358,6 +358,8 @@ def _maybe_init_kv_connector_manager(self):
                 module.register_forward_hook(
                     self.kv_connector_manager.layer_post_hook)

+        self.kv_connector_manager.wait_for_initialization()
+
     def _event_loop_wrapper(self):
         try:
             with customized_gc_thresholds(
@@ -610,7 +612,7 @@ def profile_step():
             if prev_device_step_time is None:
                 prev_device_step_time = "N/A"  # Handle first iteration
             else:
-                prev_device_step_time = f"{prev_device_step_time}ms"
+                prev_device_step_time = f"{prev_device_step_time} ms"
             host_step_time = (end_time - start_time) * 1000  # milliseconds
             formatted_timestamp = datetime.datetime.now().strftime(
                 "%Y-%m-%d %H:%M:%S")

@@ -620,7 +622,7 @@ def profile_step():
                 f"rank = {self.dist.rank}, "
                 f"currank_total_requests = {self.executor_request_queue.num_fetch_requests_cur_rank}/"
                 f"{self.executor_request_queue.num_fetch_requests}, "
-                f"host_step_time = {host_step_time}ms, "
+                f"host_step_time = {host_step_time:.3f} ms, "
                 f"prev_device_step_time = {prev_device_step_time}, "
                 f"timestamp = {formatted_timestamp}, "
                 f"num_scheduled_requests: {self.num_scheduled_requests}, "
@@ -1302,7 +1304,6 @@ def _kv_connector_start_batch(self, scheduled_batch):
         if self.kv_connector_manager:
             self.kv_connector_manager.take_scheduled_requests_pending_load(
                 scheduled_batch)
-            self.kv_connector_manager.handle_metadata()
Collaborator
Could you explain why this is removed?
             self.kv_connector_manager.worker.start_load_kv(
                 torch.cuda.current_stream())

@@ -1348,6 +1349,10 @@ def _executor_loop(self):
                 finished_requests = []

                 can_queue = self._can_queue(scheduled_batch)

+                if self.kv_connector_manager:
Collaborator
Why are we moving handle_metadata around? This will break all other implementations of the kv connector.

Collaborator
Can the unit tests already added to TRT-LLM detect and ensure that this functionality is not broken?
+                    self.kv_connector_manager.handle_metadata()
+
                 if can_queue:
                     if self.kv_cache_transceiver:
                         # For generation requests which have completed KV cache transfer
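On the unit-test question: a test could pin the relative ordering of the connector calls around scheduling. The sketch below is hypothetical; `fake_executor_iteration` and `queue_batch` are stand-ins for driving one `_executor_loop` iteration, not code from TRT-LLM or this PR:

```python
from unittest.mock import MagicMock


def fake_executor_iteration(manager, scheduled_batch, can_queue):
    """Stand-in for one _executor_loop iteration, mirroring the ordering in this PR:
    pending loads are handed off first, handle_metadata runs next, then queueing proceeds."""
    if manager:
        manager.take_scheduled_requests_pending_load(scheduled_batch)
        manager.worker.start_load_kv(stream=None)
    if manager:
        manager.handle_metadata()
    if can_queue:
        manager.queue_batch(scheduled_batch)


def test_handle_metadata_runs_before_queueing():
    manager = MagicMock()
    fake_executor_iteration(manager, scheduled_batch=object(), can_queue=True)

    names = [c[0] for c in manager.mock_calls]
    # handle_metadata must come after the pending-load handoff and before queueing.
    assert names.index("take_scheduled_requests_pending_load") < names.index("handle_metadata")
    assert names.index("handle_metadata") < names.index("queue_batch")
```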
@@ -1578,6 +1583,9 @@ def _executor_loop_overlap(self):

                 self._pause_requests(scheduled_batch.paused_requests)

+                if self.kv_connector_manager:
+                    self.kv_connector_manager.handle_metadata()
+
                 can_queue = self._can_queue(scheduled_batch)
                 if can_queue:
                     if self.kv_cache_transceiver:

@@ -2634,6 +2642,9 @@ def _handle_responses(self):
                             self.ctx_in_transmission_counter))
                     else:
                         requests_to_terminate.append(request)
+
+                        if self.kv_connector_manager is not None:
+                            self.resource_manager.free_slot_only(request)
                 else:
                     new_active_requests.append(request)
@@ -177,6 +177,7 @@ def _get_mapping(_mapping: Mapping) -> Mapping:
             tp_size=tensorrt_llm.mpi_world_size(),
             gpus_per_node=tensorrt_llm.default_gpus_per_node(),
             rank=tensorrt_llm.mpi_rank())
+        executor_config.mapping = mapping
Collaborator
I think the code change should keep in mind that other types of connector exist too. Overwriting executor_config.mapping here may not be the intent of other connectors. Maybe you can guard this exclusively for the FlexKV connector?
     else:
         mapping = copy.deepcopy(_mapping)
         mapping.rank = tensorrt_llm.mpi_rank()
Where is this used? Is this field (via the bindings) only accessed inside the flexKV implementation?

Yes, currently it's only accessed inside the flexKV implementation, which can be found here. Is this acceptable?

Is there any way to infer that value from the other fields? We ideally don't include new fields in TRT-LLM that are exclusively accessed by a specific kv connector implementation.
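For illustration, one way to scope the assignment as the reviewer suggests would be to key it off the configured connector module. This is a hypothetical sketch; the `kv_connector_config` and `connector_module` attribute names and the "flexkv" substring check are assumptions, not code from this PR or TRT-LLM:

```python
def _should_propagate_mapping(executor_config) -> bool:
    """Hypothetical helper: return True only when the FlexKV connector is configured.

    The kv_connector_config / connector_module attribute names are assumptions
    made for illustration; they are not guaranteed to match TRT-LLM's config."""
    connector_cfg = getattr(executor_config, "kv_connector_config", None)
    if connector_cfg is None:
        return False
    module_name = str(getattr(connector_cfg, "connector_module", ""))
    return "flexkv" in module_name.lower()


# Inside _get_mapping, the assignment could then be guarded:
#     if _should_propagate_mapping(executor_config):
#         executor_config.mapping = mapping
```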