-
Notifications
You must be signed in to change notification settings - Fork 1
avoid copying new_tokens to cpu #118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: feat/ad-2025-07-22
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |||||||||||||||||||||||||
| from torch._ops import OpOverloadPacket | ||||||||||||||||||||||||||
| from torch.export import Dim | ||||||||||||||||||||||||||
| from torch.fx import Node | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from tensorrt_llm._utils import nvtx_range | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @dataclass | ||||||||||||||||||||||||||
| class CacheConfig: | ||||||||||||||||||||||||||
|
|
@@ -122,7 +122,7 @@ def __post_init__(self): | |||||||||||||||||||||||||
| self.max_batch_size, | ||||||||||||||||||||||||||
| (total_tokens) // self.page_size + (total_tokens % self.page_size > 0), | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int) | ||||||||||||||||||||||||||
| self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device) | ||||||||||||||||||||||||||
| self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long) | ||||||||||||||||||||||||||
| self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int) | ||||||||||||||||||||||||||
| self.input_pos = torch.empty_like(self.seq_len) | ||||||||||||||||||||||||||
|
|
@@ -336,7 +336,7 @@ def reset(self) -> None: | |||||||||||||||||||||||||
| self.input_pos.zero_() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # set a dummy sequence corresponding to a generate-only batch (will also reset position_ids) | ||||||||||||||||||||||||||
| self.nest_sequences(torch.zeros(self.max_batch_size, 1, dtype=torch.int)) | ||||||||||||||||||||||||||
| self.nest_sequences([[1]] * self.max_batch_size) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # reset cache information | ||||||||||||||||||||||||||
| self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device) | ||||||||||||||||||||||||||
|
|
@@ -381,22 +381,24 @@ def set_generate_only_batch(self) -> None: | |||||||||||||||||||||||||
| self.reset() | ||||||||||||||||||||||||||
| self.nest_sequences([[1]] * self.max_batch_size) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @nvtx_range("ad_update_position_ids") | ||||||||||||||||||||||||||
| def _update_position_ids(self) -> None: | ||||||||||||||||||||||||||
| # set new position_ids as new tensor from input_pos and seq_len via torch.arange | ||||||||||||||||||||||||||
| position_ids_list = [ | ||||||||||||||||||||||||||
| num | ||||||||||||||||||||||||||
| for in_pos, seq_len in zip(self.input_positions, self.sequence_lengths) | ||||||||||||||||||||||||||
| for num in range(in_pos, in_pos + seq_len) | ||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||
| self.position_ids = torch.tensor(position_ids_list, dtype=torch.long).to(self.device) | ||||||||||||||||||||||||||
| self.position_ids = torch.tensor(position_ids_list, dtype=torch.long, pin_memory=True).to(self.device) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len] | ||||||||||||||||||||||||||
| if self.is_generate: | ||||||||||||||||||||||||||
| self.position_ids = self.position_ids.view(-1, 1) | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| self.position_ids = self.position_ids.view(1, -1) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None: | ||||||||||||||||||||||||||
| @nvtx_range("ad_nest_sequences") | ||||||||||||||||||||||||||
| def nest_sequences(self, input_ids: Sequence[Sequence[int]], previous_batch_indices: List[int] = [], new_tokens: Optional[torch.Tensor] = None) -> None: | ||||||||||||||||||||||||||
| """Create and store a flattened list of input_ids from the provided list of sequences. | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| This i/f will also update any relevant sequence information. | ||||||||||||||||||||||||||
|
|
@@ -413,8 +415,10 @@ def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None: | |||||||||||||||||||||||||
| for lst in input_ids | ||||||||||||||||||||||||||
| for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst) | ||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||
| self.input_ids = torch.tensor(ids_list, dtype=dtype).to(self.device) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| self.input_ids = torch.tensor(ids_list, dtype=dtype, pin_memory=True).to(self.device) | ||||||||||||||||||||||||||
| if new_tokens is not None: | ||||||||||||||||||||||||||
| self.input_ids[self.input_ids == -1] = new_tokens[0,previous_batch_indices,0] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
Comment on lines
+420
to
+421
|
||||||||||||||||||||||||||
| self.input_ids[self.input_ids == -1] = new_tokens[0,previous_batch_indices,0] | |
| # Validate new_tokens dimensions | |
| if new_tokens.dim() < 3: | |
| raise ValueError(f"new_tokens must have at least 3 dimensions, but got {new_tokens.dim()}.") | |
| # Validate previous_batch_indices | |
| max_index = new_tokens.size(1) - 1 | |
| if any(idx < 0 or idx > max_index for idx in previous_batch_indices): | |
| raise IndexError(f"previous_batch_indices contains out-of-bounds indices for new_tokens' second dimension (valid range: 0 to {max_index}).") | |
| self.input_ids[self.input_ids == -1] = new_tokens[0, previous_batch_indices, 0] |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -101,7 +101,7 @@ def build_from_config(cls, ad_config: AutoDeployConfig): | |||||
| page_size=attn_page_size, | ||||||
| max_num_tokens=max_num_tokens, | ||||||
| ) | ||||||
|
|
||||||
| print(" in seq_info for device: ", torch.cuda.current_device()) | ||||||
|
||||||
| print(" in seq_info for device: ", torch.cuda.current_device()) | |
| ad_logger.info(f"in seq_info for device: {torch.cuda.current_device()}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tensor is created with device=self.device, but self.device is a property that depends on self.input_pos.device. At this point in post_init, input_pos hasn't been initialized yet, which could cause an AttributeError.