Skip to content

Commit 17029de

Browse files
Removed some debug print statements and added a license header
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
1 parent 0615b0e commit 17029de

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def capture_hidden_states(self, cache_seq_interface: CachedSequenceInterface) ->
188188
num_tokens = sum(cache_seq_interface.info.seq_len)
189189

190190
hidden_states = [hidden_state[:num_tokens] for hidden_state in full_hidden_states]
191-
hidden_states = torch.cat(hidden_states, dim=1) if hidden_states else None
191+
hidden_states = torch.cat(hidden_states, dim=1)
192192
hidden_states = hidden_states.to(dtype=self.dtype)
193193

194194
token_idx = self.hidden_state_write_indices[:num_tokens]

tensorrt_llm/_torch/auto_deploy/transform/library/hidden_states.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
116
"""The transform passes to capture the hidden states of the target model."""
217

318
from typing import Dict, List, Optional, Set, Tuple, Type
@@ -47,13 +62,8 @@ def cached_residual_add(
4762
) -> torch.Tensor:
4863
ret = torch.ops.aten.add(t1, t2)
4964
b, s, _ = ret.shape
50-
print(f"In cached residual add. Ret shape: {ret.shape}")
51-
print(f"Shape of hidden_states_cache: {hidden_states_cache.shape}")
5265
num_tokens = b * s
53-
print(f"Num tokens: {num_tokens}")
5466

55-
# TODO(govind): do some of these correspond to padding tokens when there are varying sequence lengths?
56-
# Might need to extract the actual sequence lengths from somewhere to get the appropriate indices to copy.
5767
hidden_states_cache[:num_tokens].copy_(ret.view(num_tokens, -1), non_blocking=True)
5868
return ret
5969

@@ -115,8 +125,6 @@ class DetectHiddenStatesForCaptureConfig(TransformConfig):
115125
"""Configuration for the hidden states detection transform."""
116126

117127
# TODO: figure out how to get layers to capture.
118-
# Right now default is None and EagleSpecMetadata has a heuristic to extract layer indices to capture.
119-
# This seems fragile.
120128
# We should consider if we can use the layer indices stored in eagle checkpoints, e.g.
121129
# https://huggingface.co/nvidia/gpt-oss-120b-Eagle3/blob/main/config.json#L9-L14
122130
eagle3_layers_to_capture: Optional[Set[int]] = None # Default: Do not capture any layers
@@ -145,15 +153,12 @@ def _apply(
145153

146154
def _get_layer_number(lin_node: Node) -> Optional[int]:
147155
weight = lin_node.args[1]
148-
print(f"Calling _get_layer_number() with lin_node: {lin_node}")
149156
if weight.op == "get_attr":
150157
subnames = weight.target.split(".")
151158
for subname in subnames:
152159
if subname.isdigit():
153-
print(f"Found layer number: {int(subname)}")
154160
return int(subname)
155161

156-
print("No layer number found")
157162
return None
158163

159164
# find last closing linear node of each layer

0 commit comments

Comments (0)