Skip to content

Commit 7287d79

Browse files
revert _compute_logits() args
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent 602fa0e commit 7287d79

File tree

1 file changed: +11 additions, −10 deletions

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -747,18 +747,10 @@ def _build_input_ids(request) -> Tuple[List[int], List[int], bool]:
747747
return last_logit_only
748748

749749
@nvtx_range("ad_compute_logits")
750-
def _compute_logits(self, resource_manager: ResourceManager) -> List[torch.Tensor]:
750+
def _compute_logits(self) -> List[torch.Tensor]:
751751
# run the model
752752
logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0]
753753

754-
spec_resource_manager = resource_manager.get_resource_manager(
755-
ResourceManagerType.SPEC_RESOURCE_MANAGER
756-
)
757-
if spec_resource_manager is not None and isinstance(
758-
spec_resource_manager, ADHiddenStateManager
759-
):
760-
spec_resource_manager.capture_hidden_states(self.cache_seq_interface)
761-
762754
# TRTLLMSampler expects float32 logits. PyTorchModelEngine always casts to float32 regardless.
763755
logits = logits.float()
764756

@@ -786,7 +778,16 @@ def forward(
786778
self.iter_counter += 1
787779

788780
# compute all logits
789-
logits = self._compute_logits(resource_manager)
781+
logits = self._compute_logits()
782+
783+
# save hidden states after running model.forward() in _compute_logits()
784+
spec_resource_manager = resource_manager.get_resource_manager(
785+
ResourceManagerType.SPEC_RESOURCE_MANAGER
786+
)
787+
if spec_resource_manager is not None and isinstance(
788+
spec_resource_manager, ADHiddenStateManager
789+
):
790+
spec_resource_manager.capture_hidden_states(self.cache_seq_interface)
790791

791792
# gather+cat logits
792793
logits_flat = torch.cat(

Comments (0)