Skip to content

Commit e1f56c4

Browse files
authored
feat: add diagnostic script for problematic embeddings (#896)
Signed-off-by: Terry Kong <[email protected]>
1 parent 223bfa8 commit e1f56c4

File tree

2 files changed

+327
-0
lines changed

2 files changed

+327
-0
lines changed

docs/adding-new-models.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,42 @@ uv run --extra vllm tools/model_diagnostics/2.long_generation_decode_vs_prefill.
152152
# ...
153153
# [Qwen/Qwen2.5-1.5B] ALL GOOD!
154154
```
155+
156+
## [3.check_hf_model_embeddings_untrained.py](https://github.com/NVIDIA-NeMo/RL/blob/main/tools/model_diagnostics/3.check_hf_model_embeddings_untrained.py)
157+
158+
Detects untrained or improperly initialized Hugging Face model embeddings by scanning for near-zero rows and rows with near-identical values in both input and output embeddings. The script also reports whether word embeddings are tied and summarizes basic statistics.
159+
160+
```sh
161+
# Example run
162+
uv run --extra mcore tools/model_diagnostics/3.check_hf_model_embeddings_untrained.py --model nvidia/Nemotron-H-8B-Base-8K
163+
164+
# ....
165+
#================================================================================
166+
#EMBEDDING SUMMARIES
167+
#================================================================================
168+
#
169+
#--- Input Embeddings Summary ---
170+
#Shape: torch.Size([131072, 4096]), Dtype: torch.bfloat16
171+
#Near-zero embeddings (abs < 1.00e-10): 1039/131072 (0.8%)
172+
# Indices: 0-1,3-999,1192-1193,1245-1255,55014,77579,81772,81819,82312,82500,82725,82737,82977,84020,84121,84521,84794,85015,86409,87411,89412,90320,91368,94485,96385,104097,108262,112147,112327,112497,114755
173+
#Identical embeddings (std < 1.00e-08): 1041/131072 (0.8%)
174+
# Indices: 0-1,3-999,1192-1193,1245-1255,55014,77579,81772,81819,82312,82500,82725,82737,82977,83855,84020,84121,84521,84794,85015,86409,87411,89412,90320,91368,94485,96385,101707,104097,108262,112147,112327,112497,114755
175+
#Statistics: mean_abs=0.007874, max_abs=0.196289, std_range=[0.000000, 0.015442]
176+
#⚠️ POTENTIAL ISSUES: 1039 near-zero embeddings, 1041 identical embeddings
177+
#
178+
#--- Output Embeddings Summary (Tied: False) ---
179+
#Shape: torch.Size([131072, 4096]), Dtype: torch.bfloat16
180+
#Near-zero embeddings (abs < 1.00e-10): 0/131072 (0.0%)
181+
#Identical embeddings (std < 1.00e-08): 0/131072 (0.0%)
182+
#Statistics: mean_abs=0.006775, max_abs=0.200195, std_range=[0.004089, 0.021240]
183+
#✅ No obvious untrained patterns detected
184+
#
185+
#=== Final Summary ===
186+
#Model: nvidia/Nemotron-H-8B-Base-8K
187+
#Analysis complete.
188+
```
189+
190+
- Thresholds can be adjusted via flags:
191+
- `--near-zero-threshold` (default: `1e-10`)
192+
- `--identical-threshold` (default: `1e-8`)
193+
- If any near-zero or identical rows are reported, the model may suffer from numerical instability (e.g., inf grad norms) during post-training whenever one of these problematic tokens is encountered. We have observed this when special tokens are reserved in the tokenizer and embedding table but are never encountered during pre-training. It may help to initialize these embeddings the same way they were initialized during pre-training.
Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Model Diagnostic: Check HuggingFace Model Embeddings for Untrained Patterns.
16+
17+
This script loads a HuggingFace model and analyzes the input and output embeddings
18+
to detect patterns that suggest the model may be untrained or improperly initialized.
19+
20+
uv run --extra mcore 3.check_hf_model_embeddings_untrained.py --model nvidia/Nemotron-H-8B-Base-8K
21+
"""
22+
23+
import argparse
24+
25+
import torch
26+
from transformers import AutoModelForCausalLM, AutoTokenizer
27+
28+
29+
def format_index_ranges(indices):
    """Format indices into a compact range string such as '0-1,3-6'.

    A run of values where each one is exactly one greater than the previous
    is collapsed into ``start-end``; a value that is not part of a run is
    emitted on its own. Values are consumed in the order given and are not
    sorted or de-duplicated, so pass them in ascending order to get maximal
    runs.

    Args:
        indices: Iterable of integer indices (list, tuple, generator, ...).

    Returns:
        The comma-separated ranges, or an empty string when there are no
        indices.
    """
    # Materialize once: the emptiness check and the [0] lookup below need a
    # real sequence, which a generator or other one-shot iterable is not.
    indices = list(indices)
    if not indices:
        return ""

    ranges = []
    start = end = indices[0]

    for idx in indices[1:]:
        if idx == end + 1:
            # Still inside the current consecutive run.
            end = idx
            continue
        # The run is broken: flush it and start a new one at idx.
        ranges.append(str(start) if start == end else f"{start}-{end}")
        start = end = idx

    # Flush the final run, which the loop never gets a chance to emit.
    ranges.append(str(start) if start == end else f"{start}-{end}")
    return ",".join(ranges)
47+
48+
49+
def get_token_info(tokenizer, idx):
    """Return a printable description of the token stored at ``idx``.

    Args:
        tokenizer: Object exposing ``decode(list_of_ids)``, or ``None`` when
            no tokenizer is available.
        idx: Token id to decode.

    Returns:
        ``repr()`` of the decoded text, or the string ``"N/A"`` when there is
        no tokenizer or decoding fails.
    """
    # Compare against None explicitly. A tokenizer may define __len__, in
    # which case plain truthiness would treat an empty one as "missing".
    if tokenizer is None:
        return "N/A"
    try:
        return repr(tokenizer.decode([idx]))
    except Exception:
        # Deliberately best-effort: this is only used to annotate diagnostic
        # output, so an undecodable id must not abort the whole report.
        return "N/A"
57+
58+
59+
def print_problematic_embeddings(
    weights, indices, problem_type, metric_values, threshold, tokenizer=None
):
    """Print one detailed entry for every flagged embedding row.

    For each index the report shows the metric that caused the row to be
    flagged, the decoded token (when a tokenizer is supplied) and the first
    two and last two values of the row.

    Args:
        weights: Embedding matrix, indexed by row as ``weights[idx]``.
        indices: Row indices that were flagged; nothing is printed when empty.
        problem_type: Label used in the output (e.g. ``"Near-zero"``).
        metric_values: Per-row metric, indexed as ``metric_values[idx]``; each
            entry must support ``.item()``.
        threshold: Threshold the metric was compared against.
        tokenizer: Optional tokenizer used to decode each index.

    Returns:
        None. Output goes to stdout.
    """
    if not indices:
        return

    print(f"\n--- Detailed {problem_type} Embeddings ---")
    for idx in indices:
        embedding = weights[idx]
        metric_val = metric_values[idx].item()
        token_info = get_token_info(tokenizer, idx)

        # Only the head and tail of the row are shown to keep output readable.
        # NOTE(review): this indexes element [1], so rows are assumed to hold
        # at least two values.
        first_two = embedding[:2].tolist()
        last_two = embedding[-2:].tolist()

        # Rows are flagged because their metric is BELOW the threshold, so
        # the comparison shown to the user must be "<".
        print(
            f"Index {idx}: {problem_type} (metric: {metric_val:.2e} < {threshold:.2e})"
        )
        print(f" Token: {token_info}")
        print(
            f" Values: [{first_two[0]:.2e}, {first_two[1]:.2e}, ..., {last_two[0]:.2e}, {last_two[1]:.2e}]"
        )
83+
84+
85+
def find_output_embeddings(model):
    """Locate the output-embedding layer of a model.

    Candidates are tried in order: the ``get_output_embeddings()`` accessor,
    then the ``lm_head`` attribute, then the ``embed_out`` attribute.

    Args:
        model: Model object to inspect.

    Returns:
        The first candidate found, or ``None`` when the model exposes none of
        them.
    """
    if hasattr(model, "get_output_embeddings"):
        output_embeddings = model.get_output_embeddings()
        # The accessor can exist and still return None; in that case keep
        # looking instead of giving up, otherwise the attribute fallbacks
        # below would never be reached for such models.
        if output_embeddings is not None:
            return output_embeddings

    if hasattr(model, "lm_head"):
        return model.lm_head
    if hasattr(model, "embed_out"):
        return model.embed_out
    return None
94+
95+
96+
def check_embedding_layer(
    embeddings,
    layer_name,
    near_zero_threshold,
    identical_threshold,
    tokenizer=None,
    model=None,
):
    """Scan one embedding layer for rows that look untrained.

    Two patterns are flagged:

    * near-zero rows: every value in the row has absolute value below
      ``near_zero_threshold``;
    * identical rows: the standard deviation across the row is below
      ``identical_threshold``.

    A per-row report for the flagged rows is printed immediately; aggregate
    numbers are returned so the caller can print all summaries together.

    Args:
        embeddings: Object with a ``weight`` attribute, or one exposing
            ``.data`` directly. Rows are reduced along ``dim=1``, so the
            weights are expected to be 2-D.
        layer_name: Label used in output. The tied-embedding flag is only
            reported when this is exactly ``"Output Embeddings"``.
        near_zero_threshold: Absolute-value threshold for near-zero rows.
        identical_threshold: Standard-deviation threshold for identical rows.
        tokenizer: Optional tokenizer used to decode flagged indices.
        model: Optional model; its ``config.tie_word_embeddings`` is read.

    Returns:
        A dict with keys ``layer_name``, ``tied_info``, ``shape``, ``dtype``,
        ``num_near_zero``, ``num_identical``, ``total_embeddings``,
        ``near_zero_indices``, ``identical_indices``, ``near_zero_threshold``,
        ``identical_threshold``, ``mean_abs``, ``max_abs``, ``min_std``,
        ``max_std`` and ``issues``.
    """
    print(f"\n=== {layer_name} Analysis ===")

    # Report the tied flag for the output layer only.
    # NOTE(review): this is the config flag; whether the two layers actually
    # share storage is not verified here.
    tied_info = ""
    if (
        layer_name == "Output Embeddings"
        and model is not None
        and hasattr(model, "config")
    ):
        tied = getattr(model.config, "tie_word_embeddings", False)
        tied_info = f" (Tied: {tied})"
        print(f"Tied word embeddings: {tied}")

    # Get embedding weights
    weights = (
        embeddings.weight.data if hasattr(embeddings, "weight") else embeddings.data
    )

    print(f"Shape: {weights.shape}")
    print(f"Dtype: {weights.dtype}")

    # Take the absolute value once and reuse it: the matrix is large, and
    # every statistic below is derived from it.
    abs_weights = torch.abs(weights)
    max_abs_values = abs_weights.max(dim=1)[0]

    # A row is near-zero when all of its values are below the threshold,
    # which is the same as its largest absolute value being below it.
    near_zero_rows = max_abs_values < near_zero_threshold
    near_zero_indices = torch.where(near_zero_rows)[0].tolist()

    # A row whose values are (almost) all the same has (almost) no spread.
    row_stds = weights.std(dim=1)
    identical_mask = row_stds < identical_threshold
    identical_indices = torch.where(identical_mask)[0].tolist()

    # Print detailed problematic embeddings
    print_problematic_embeddings(
        weights,
        near_zero_indices,
        "Near-zero",
        max_abs_values,
        near_zero_threshold,
        tokenizer,
    )
    print_problematic_embeddings(
        weights,
        identical_indices,
        "Identical",
        row_stds,
        identical_threshold,
        tokenizer,
    )

    num_near_zero = len(near_zero_indices)
    num_identical = len(identical_indices)
    total_embeddings = weights.shape[0]

    # Flag potential issues
    issues = []
    if num_near_zero > 0:
        issues.append(f"{num_near_zero} near-zero embeddings")
    if num_identical > 0:
        issues.append(f"{num_identical} identical embeddings")

    # The summary is returned rather than printed so the caller controls
    # where and when it is displayed.
    return {
        "layer_name": layer_name,
        "tied_info": tied_info,
        "shape": weights.shape,
        "dtype": weights.dtype,
        "num_near_zero": num_near_zero,
        "num_identical": num_identical,
        "total_embeddings": total_embeddings,
        "near_zero_indices": near_zero_indices,
        "identical_indices": identical_indices,
        "near_zero_threshold": near_zero_threshold,
        "identical_threshold": identical_threshold,
        "mean_abs": abs_weights.mean().item(),
        # The global maximum is the largest of the per-row maxima.
        "max_abs": max_abs_values.max().item(),
        "min_std": row_stds.min().item(),
        "max_std": row_stds.max().item(),
        "issues": issues,
    }
181+
182+
183+
def main():
    """Command-line entry point: load a model and report on its embeddings.

    Parses ``--model``, ``--near-zero-threshold`` and
    ``--identical-threshold``, loads the model and tokenizer, analyzes the
    input and output embedding layers, and prints one summary per layer that
    was found.
    """
    parser = argparse.ArgumentParser(
        description="Check HuggingFace model embeddings for untrained patterns"
    )
    parser.add_argument(
        "--model",
        default="nvidia/Nemotron-H-8B-Base-8K",
        help="HuggingFace model name or path",
    )
    parser.add_argument(
        "--near-zero-threshold",
        type=float,
        default=1e-10,
        help="Threshold for detecting near-zero embeddings (default: 1e-10)",
    )
    parser.add_argument(
        "--identical-threshold",
        type=float,
        default=1e-8,
        help="Threshold for detecting identical embeddings via std dev (default: 1e-8)",
    )

    args = parser.parse_args()

    print(f"Loading model: {args.model}")

    # NOTE(security): trust_remote_code=True lets the model repository run
    # its own Python code on this machine. Only point --model at sources you
    # trust.
    model = AutoModelForCausalLM.from_pretrained(
        args.model, torch_dtype="auto", trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

    print("Model loaded successfully")
    print(f"Model type: {type(model).__name__}")
    print(f"Vocabulary size: {len(tokenizer)}")

    # Each entry: (label in the report, word used in the "not found" warning,
    # zero-argument lookup). Lookups are called lazily so the output layer is
    # only resolved after the input layer has been analyzed.
    layers = (
        ("Input Embeddings", "input", model.get_input_embeddings),
        ("Output Embeddings", "output", lambda: find_output_embeddings(model)),
    )

    # Collect the summaries first so they can be printed together at the end,
    # after all the per-row detail output.
    summaries = []
    for layer_name, kind, lookup in layers:
        layer = lookup()
        if layer is None:
            print(f"\n⚠️ Could not find {kind} embeddings layer")
            continue
        summaries.append(
            check_embedding_layer(
                layer,
                layer_name,
                args.near_zero_threshold,
                args.identical_threshold,
                tokenizer,
                model,
            )
        )

    # Print summaries together
    print("\n" + "=" * 80)
    print("EMBEDDING SUMMARIES")
    print("=" * 80)

    for summary in summaries:
        _print_layer_summary(summary)

    print("\n=== Final Summary ===")
    print(f"Model: {args.model}")
    print("Analysis complete.")


def _percentage(count, total):
    """Return ``count`` as a percentage of ``total``, or 0.0 if ``total`` is 0."""
    return 100 * count / total if total else 0.0


def _print_layer_summary(summary):
    """Print the aggregate report for one summary dict from check_embedding_layer."""
    total = summary["total_embeddings"]

    print(f"\n--- {summary['layer_name']} Summary{summary['tied_info']} ---")
    print(f"Shape: {summary['shape']}, Dtype: {summary['dtype']}")

    print(
        f"Near-zero embeddings (abs < {summary['near_zero_threshold']:.2e}): {summary['num_near_zero']}/{total} ({_percentage(summary['num_near_zero'], total):.1f}%)"
    )
    if summary["near_zero_indices"]:
        print(f" Indices: {format_index_ranges(summary['near_zero_indices'])}")

    print(
        f"Identical embeddings (std < {summary['identical_threshold']:.2e}): {summary['num_identical']}/{total} ({_percentage(summary['num_identical'], total):.1f}%)"
    )
    if summary["identical_indices"]:
        print(f" Indices: {format_index_ranges(summary['identical_indices'])}")

    print(
        f"Statistics: mean_abs={summary['mean_abs']:.6f}, max_abs={summary['max_abs']:.6f}, std_range=[{summary['min_std']:.6f}, {summary['max_std']:.6f}]"
    )

    if summary["issues"]:
        print(f"⚠️ POTENTIAL ISSUES: {', '.join(summary['issues'])}")
    else:
        print("✅ No obvious untrained patterns detected")
285+
286+
287+
if __name__ == "__main__":
288+
main()

0 commit comments

Comments
 (0)