|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
1 | 16 | """The transform passes to capture the hidden states of the target model.""" |
2 | 17 |
|
3 | 18 | from typing import Dict, List, Optional, Set, Tuple, Type |
@@ -47,13 +62,8 @@ def cached_residual_add( |
47 | 62 | ) -> torch.Tensor: |
48 | 63 | ret = torch.ops.aten.add(t1, t2) |
49 | 64 | b, s, _ = ret.shape |
50 | | - print(f"In cached residual add. Ret shape: {ret.shape}") |
51 | | - print(f"Shape of hidden_states_cache: {hidden_states_cache.shape}") |
52 | 65 | num_tokens = b * s |
53 | | - print(f"Num tokens: {num_tokens}") |
54 | 66 |
|
55 | | - # TODO(govind): do some of these correspond to padding tokens when there are varying sequence lengths? |
56 | | - # Might need to extract the actual sequence lengths from somewhere to get the appropriate indices to copy. |
57 | 67 | hidden_states_cache[:num_tokens].copy_(ret.view(num_tokens, -1), non_blocking=True) |
58 | 68 | return ret |
59 | 69 |
|
@@ -115,8 +125,6 @@ class DetectHiddenStatesForCaptureConfig(TransformConfig): |
115 | 125 | """Configuration for the hidden states detection transform.""" |
116 | 126 |
|
117 | 127 | # TODO: figure out how to get layers to capture. |
118 | | - # Right now default is None and EagleSpecMetadata has a heuristic to extract layer indices to capture. |
119 | | - # This seems fragile. |
120 | 128 | # We should consider if we can use the layer indices stored in eagle checkpoints, e.g. |
121 | 129 | # https://huggingface.co/nvidia/gpt-oss-120b-Eagle3/blob/main/config.json#L9-L14 |
122 | 130 | eagle3_layers_to_capture: Optional[Set[int]] = None # Default: Do not capture any layers |
@@ -145,15 +153,12 @@ def _apply( |
145 | 153 |
|
146 | 154 | def _get_layer_number(lin_node: Node) -> Optional[int]: |
147 | 155 | weight = lin_node.args[1] |
148 | | - print(f"Calling _get_layer_number() with lin_node: {lin_node}") |
149 | 156 | if weight.op == "get_attr": |
150 | 157 | subnames = weight.target.split(".") |
151 | 158 | for subname in subnames: |
152 | 159 | if subname.isdigit(): |
153 | | - print(f"Found layer number: {int(subname)}") |
154 | 160 | return int(subname) |
155 | 161 |
|
156 | | - print("No layer number found") |
157 | 162 | return None |
158 | 163 |
|
159 | 164 | # find last closing linear node of each layer |
|
0 commit comments