|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +"""Model Diagnostic: Check HuggingFace Model Embeddings for Untrained Patterns. |
| 16 | +
|
| 17 | +This script loads a HuggingFace model and analyzes the input and output embeddings |
| 18 | +to detect patterns that suggest the model may be untrained or improperly initialized. |
| 19 | +
|
| 20 | +uv run --extra mcore 3.check_hf_model_embeddings_untrained.py --model nvidia/Nemotron-H-8B-Base-8K |
| 21 | +""" |
| 22 | + |
| 23 | +import argparse |
| 24 | + |
| 25 | +import torch |
| 26 | +from transformers import AutoModelForCausalLM, AutoTokenizer |
| 27 | + |
| 28 | + |
def format_index_ranges(indices):
    """Collapse a sequence of indices into a compact range string.

    Runs in which each value is exactly one greater than the previous value
    are rendered as ``start-end``; a value that is not part of a run is
    rendered on its own. The pieces are joined with commas, so
    ``[0, 1, 3, 4, 5, 6]`` becomes ``'0-1,3-6'``.

    The input is walked in the order given; it is neither sorted nor
    de-duplicated. An empty input produces an empty string.
    """
    if not indices:
        return ""

    def _render(lo, hi):
        # A run of length one is shown as a bare number.
        return str(lo) if lo == hi else f"{lo}-{hi}"

    pieces = []
    run_start = run_end = indices[0]

    for value in indices[1:]:
        if value == run_end + 1:
            # Still inside the current consecutive run.
            run_end = value
            continue
        # Run broken: flush it and start a new one at this value.
        pieces.append(_render(run_start, run_end))
        run_start = run_end = value

    # Flush the run that was still open when the input ended.
    pieces.append(_render(run_start, run_end))
    return ",".join(pieces)
| 47 | + |
| 48 | + |
def get_token_info(tokenizer, idx):
    """Return a printable representation of the token at index ``idx``.

    Args:
        tokenizer: Object exposing ``decode(list_of_ids)``, or ``None`` when
            no tokenizer is available.
        idx: Vocabulary index to decode.

    Returns:
        ``repr()`` of the decoded text, or the string ``"N/A"`` when there is
        no tokenizer or decoding fails.
    """
    # Compare against None explicitly instead of relying on truthiness: an
    # object that supports len() is falsy when its length is 0, which would
    # make a real tokenizer look like a missing one.
    if tokenizer is None:
        return "N/A"
    try:
        return repr(tokenizer.decode([idx]))
    except Exception:
        # Deliberate best-effort: this only decorates diagnostic output, so a
        # decode failure for one index must not abort the whole report.
        return "N/A"
| 57 | + |
| 58 | + |
def print_problematic_embeddings(
    weights, indices, problem_type, metric_values, threshold, tokenizer=None
):
    """Print one detail entry per flagged embedding row.

    Args:
        weights: Row-indexable collection of embeddings; each row must support
            ``len()``, slicing and ``tolist()`` (e.g. a 2-D ``torch.Tensor``).
        indices: Row indices to report. Nothing is printed when empty.
        problem_type: Label used in the section header and on each entry.
        metric_values: Per-row metric; ``metric_values[idx].item()`` is shown.
        threshold: The cutoff the metric was compared against.
        tokenizer: Optional tokenizer used to show the token for each index.

    Returns:
        None. All output goes to stdout.
    """
    if not indices:
        return

    print(f"\n--- Detailed {problem_type} Embeddings ---")
    for idx in indices:
        embedding = weights[idx]
        metric_val = metric_values[idx].item()
        token_info = get_token_info(tokenizer, idx)

        if len(embedding) >= 4:
            # Only the first two and last two values are shown for long rows.
            shown = embedding[:2].tolist() + embedding[-2:].tolist()
            parts = [f"{value:.2e}" for value in shown]
            parts.insert(2, "...")
        else:
            # Short rows are printed in full; slicing head and tail would
            # repeat values (or index past the end for a single value).
            parts = [f"{value:.2e}" for value in embedding.tolist()]

        # Rows are flagged because their metric is BELOW the threshold, so the
        # relation printed here is "<".
        print(
            f"Index {idx}: {problem_type} (metric: {metric_val:.2e} < {threshold:.2e})"
        )
        print(f" Token: {token_info}")
        print(f" Values: [{', '.join(parts)}]")
| 83 | + |
| 84 | + |
def find_output_embeddings(model):
    """Locate the output embedding layer of ``model``.

    Lookup order:
      1. ``model.get_output_embeddings()`` when the method exists and returns
         something other than ``None``.
      2. The ``lm_head`` attribute.
      3. The ``embed_out`` attribute.

    Args:
        model: Model object to inspect.

    Returns:
        The first non-``None`` candidate found, or ``None`` when no candidate
        is available.
    """
    getter = getattr(model, "get_output_embeddings", None)
    if callable(getter):
        head = getter()
        # The accessor may exist and still return None; fall through to the
        # attribute lookups in that case instead of giving up.
        if head is not None:
            return head

    for attr_name in ("lm_head", "embed_out"):
        head = getattr(model, attr_name, None)
        if head is not None:
            return head
    return None
| 94 | + |
| 95 | + |
def check_embedding_layer(
    embeddings,
    layer_name,
    near_zero_threshold,
    identical_threshold,
    tokenizer=None,
    model=None,
):
    """Check an embedding layer for untrained patterns.

    Two per-row checks are run on the layer's weight matrix:

    * "near-zero": every value in the row has absolute value below
      ``near_zero_threshold``.
    * "identical": the standard deviation of the row's own values is below
      ``identical_threshold`` (i.e. the row is close to constant).

    Details for each flagged row are printed; aggregate numbers are returned
    so the caller can print all layer summaries together.

    Args:
        embeddings: Layer with a ``weight`` attribute, or a tensor-like object
            exposing ``.data`` directly.
        layer_name: Display name. The exact value ``"Output Embeddings"``
            additionally triggers reporting of ``config.tie_word_embeddings``.
        near_zero_threshold: Cutoff for the near-zero check.
        identical_threshold: Cutoff for the row standard-deviation check.
        tokenizer: Optional tokenizer used to show tokens for flagged rows.
        model: Optional model whose ``config`` is consulted for weight tying.

    Returns:
        dict with keys ``layer_name``, ``tied_info``, ``shape``, ``dtype``,
        ``num_near_zero``, ``num_identical``, ``total_embeddings``,
        ``near_zero_indices``, ``identical_indices``, ``near_zero_threshold``,
        ``identical_threshold``, ``mean_abs``, ``max_abs``, ``min_std``,
        ``max_std`` and ``issues`` (list of human-readable findings).
    """
    print(f"\n=== {layer_name} Analysis ===")

    # Weight tying is only reported for the output layer. Use an explicit
    # None check so a model object that happens to be falsy is not skipped.
    tied_info = ""
    if (
        layer_name == "Output Embeddings"
        and model is not None
        and hasattr(model, "config")
    ):
        tied = getattr(model.config, "tie_word_embeddings", False)
        tied_info = f" (Tied: {tied})"
        print(f"Tied word embeddings: {tied}")

    # Accept either a layer (has .weight) or a bare parameter/tensor.
    weights = (
        embeddings.weight.data if hasattr(embeddings, "weight") else embeddings.data
    )

    print(f"Shape: {weights.shape}")
    print(f"Dtype: {weights.dtype}")

    # Compute |W| once; it is a full-size copy of the matrix and feeds the
    # near-zero check, the per-row metric and the global statistics.
    abs_weights = torch.abs(weights)
    max_abs_values = abs_weights.max(dim=1)[0]

    # A row is near-zero when all of its |values| are below the threshold,
    # which is the same as its largest |value| being below the threshold.
    near_zero_rows = max_abs_values < near_zero_threshold
    near_zero_indices = torch.where(near_zero_rows)[0].tolist()

    # A row whose own values barely vary is treated as "identical".
    row_stds = weights.std(dim=1)
    identical_mask = row_stds < identical_threshold
    identical_indices = torch.where(identical_mask)[0].tolist()

    print_problematic_embeddings(
        weights,
        near_zero_indices,
        "Near-zero",
        max_abs_values,
        near_zero_threshold,
        tokenizer,
    )
    print_problematic_embeddings(
        weights,
        identical_indices,
        "Identical",
        row_stds,
        identical_threshold,
        tokenizer,
    )

    num_near_zero = len(near_zero_indices)
    num_identical = len(identical_indices)

    issues = []
    if num_near_zero > 0:
        issues.append(f"{num_near_zero} near-zero embeddings")
    if num_identical > 0:
        issues.append(f"{num_identical} identical embeddings")

    return {
        "layer_name": layer_name,
        "tied_info": tied_info,
        "shape": weights.shape,
        "dtype": weights.dtype,
        "num_near_zero": num_near_zero,
        "num_identical": num_identical,
        "total_embeddings": weights.shape[0],
        "near_zero_indices": near_zero_indices,
        "identical_indices": identical_indices,
        "near_zero_threshold": near_zero_threshold,
        "identical_threshold": identical_threshold,
        "mean_abs": abs_weights.mean().item(),
        # The global maximum equals the maximum of the per-row maxima.
        "max_abs": max_abs_values.max().item(),
        "min_std": row_stds.min().item(),
        "max_std": row_stds.max().item(),
        "issues": issues,
    }
| 181 | + |
| 182 | + |
def _percentage(count, total):
    """Return ``count`` as a percentage of ``total`` (0.0 when ``total`` is 0)."""
    return 100 * count / total if total else 0.0


def _print_layer_summary(summary):
    """Print the summary section for one result dict from check_embedding_layer."""
    total = summary["total_embeddings"]

    print(f"\n--- {summary['layer_name']} Summary{summary['tied_info']} ---")
    print(f"Shape: {summary['shape']}, Dtype: {summary['dtype']}")

    print(
        f"Near-zero embeddings (abs < {summary['near_zero_threshold']:.2e}): "
        f"{summary['num_near_zero']}/{total} "
        f"({_percentage(summary['num_near_zero'], total):.1f}%)"
    )
    if summary["near_zero_indices"]:
        print(f" Indices: {format_index_ranges(summary['near_zero_indices'])}")

    print(
        f"Identical embeddings (std < {summary['identical_threshold']:.2e}): "
        f"{summary['num_identical']}/{total} "
        f"({_percentage(summary['num_identical'], total):.1f}%)"
    )
    if summary["identical_indices"]:
        print(f" Indices: {format_index_ranges(summary['identical_indices'])}")

    print(
        f"Statistics: mean_abs={summary['mean_abs']:.6f}, "
        f"max_abs={summary['max_abs']:.6f}, "
        f"std_range=[{summary['min_std']:.6f}, {summary['max_std']:.6f}]"
    )

    if summary["issues"]:
        print(f"⚠️ POTENTIAL ISSUES: {', '.join(summary['issues'])}")
    else:
        print("✅ No obvious untrained patterns detected")


def main():
    """Parse CLI arguments, load the model, and report on its embeddings.

    Command-line options:
        --model: HuggingFace model name or path.
        --near-zero-threshold: Cutoff for the near-zero row check.
        --identical-threshold: Cutoff for the row standard-deviation check.

    The input embeddings and, when found, the output embeddings are each
    analyzed; per-row details are printed during analysis and the per-layer
    summaries are printed together at the end.
    """
    parser = argparse.ArgumentParser(
        description="Check HuggingFace model embeddings for untrained patterns"
    )
    parser.add_argument(
        "--model",
        default="nvidia/Nemotron-H-8B-Base-8K",
        help="HuggingFace model name or path",
    )
    parser.add_argument(
        "--near-zero-threshold",
        type=float,
        default=1e-10,
        help="Threshold for detecting near-zero embeddings (default: 1e-10)",
    )
    parser.add_argument(
        "--identical-threshold",
        type=float,
        default=1e-8,
        help="Threshold for detecting identical embeddings via std dev (default: 1e-8)",
    )

    args = parser.parse_args()

    print(f"Loading model: {args.model}")

    # NOTE(security): trust_remote_code=True allows code shipped with the
    # model repository to run locally. Only point --model at sources you trust.
    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype="auto", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

    print("Model loaded successfully")
    print(f"Model type: {type(model).__name__}")
    print(f"Vocabulary size: {len(tokenizer)}")

    # Summaries are collected first and printed together after both layers
    # have emitted their per-row details.
    summaries = []

    input_embeddings = model.get_input_embeddings()
    if input_embeddings is not None:
        summaries.append(
            check_embedding_layer(
                input_embeddings,
                "Input Embeddings",
                args.near_zero_threshold,
                args.identical_threshold,
                tokenizer,
                model,
            )
        )
    else:
        print("\n⚠️ Could not find input embeddings layer")

    output_embeddings = find_output_embeddings(model)
    if output_embeddings is not None:
        summaries.append(
            check_embedding_layer(
                output_embeddings,
                "Output Embeddings",
                args.near_zero_threshold,
                args.identical_threshold,
                tokenizer,
                model,
            )
        )
    else:
        print("\n⚠️ Could not find output embeddings layer")

    print("\n" + "=" * 80)
    print("EMBEDDING SUMMARIES")
    print("=" * 80)

    for summary in summaries:
        _print_layer_summary(summary)

    print("\n=== Final Summary ===")
    print(f"Model: {args.model}")
    print("Analysis complete.")
| 286 | + |
# Run the diagnostic only when this file is executed as a script, so that
# importing the module does not trigger a model download.
if __name__ == "__main__":
    main()
0 commit comments