File tree — 1 file changed: +11 −10 lines changed

tensorrt_llm/_torch/auto_deploy/shim — 1 file changed: +11 −10 lines changed

Original file line number | Diff line number | Diff line change
@@ -588,18 +588,10 @@ def _build_input_ids(request) -> Tuple[List[int], List[int]]:
588 588             return last_logit_only
589 589
590 590         @nvtx_range("ad_compute_logits")
591     -       def _compute_logits(self, resource_manager: ResourceManager) -> List[torch.Tensor]:
    591 +     def _compute_logits(self) -> List[torch.Tensor]:
592 592             # run the model
593 593             logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0]
594 594
595     -           spec_resource_manager = resource_manager.get_resource_manager(
596     -               ResourceManagerType.SPEC_RESOURCE_MANAGER
597     -           )
598     -           if spec_resource_manager is not None and isinstance(
599     -               spec_resource_manager, ADHiddenStateManager
600     -           ):
601     -               spec_resource_manager.capture_hidden_states(self.cache_seq_interface)
602     -
603 595             # TRTLLMSampler expects float32 logits. PyTorchModelEngine always casts to float32 regardless.
604 596             logits = logits.float()
605 597
@@ -626,7 +618,16 @@ def forward(
626 618             self.iter_counter += 1
627 619
628 620             # compute all logits
629     -           logits = self._compute_logits(resource_manager)
    621 +         logits = self._compute_logits()
    622 +
    623 +         # save hidden states after running model.forward() in _compute_logits()
    624 +         spec_resource_manager = resource_manager.get_resource_manager(
    625 +             ResourceManagerType.SPEC_RESOURCE_MANAGER
    626 +         )
    627 +         if spec_resource_manager is not None and isinstance(
    628 +             spec_resource_manager, ADHiddenStateManager
    629 +         ):
    630 +             spec_resource_manager.capture_hidden_states(self.cache_seq_interface)
630 631
631 632         # gather+cat logits
632 633         logits_flat = torch.cat(
You can’t perform that action at this time.
0 commit comments