Skip to content

Commit be8cd6c

Browse files
revert _compute_logits() args
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent 792031e commit be8cd6c

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -588,18 +588,10 @@ def _build_input_ids(request) -> Tuple[List[int], List[int]]:
588588
return last_logit_only
589589

590590
@nvtx_range("ad_compute_logits")
591-
def _compute_logits(self, resource_manager: ResourceManager) -> List[torch.Tensor]:
591+
def _compute_logits(self) -> List[torch.Tensor]:
592592
# run the model
593593
logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0]
594594

595-
spec_resource_manager = resource_manager.get_resource_manager(
596-
ResourceManagerType.SPEC_RESOURCE_MANAGER
597-
)
598-
if spec_resource_manager is not None and isinstance(
599-
spec_resource_manager, ADHiddenStateManager
600-
):
601-
spec_resource_manager.capture_hidden_states(self.cache_seq_interface)
602-
603595
# TRTLLMSampler expects float32 logits. PyTorchModelEngine always casts to float32 regardless.
604596
logits = logits.float()
605597

@@ -626,7 +618,16 @@ def forward(
626618
self.iter_counter += 1
627619

628620
# compute all logits
629-
logits = self._compute_logits(resource_manager)
621+
logits = self._compute_logits()
622+
623+
# save hidden states after running model.forward() in _compute_logits()
624+
spec_resource_manager = resource_manager.get_resource_manager(
625+
ResourceManagerType.SPEC_RESOURCE_MANAGER
626+
)
627+
if spec_resource_manager is not None and isinstance(
628+
spec_resource_manager, ADHiddenStateManager
629+
):
630+
spec_resource_manager.capture_hidden_states(self.cache_seq_interface)
630631

631632
# gather+cat logits
632633
logits_flat = torch.cat(

0 commit comments

Comments
 (0)