Changes from all commits
85 commits
e1c7506
Merge pull request #945 from TransformerLensOrg/dev
bryce13950 Jun 12, 2025
a634e57
Merge pull request #952 from TransformerLensOrg/dev
bryce13950 Jun 19, 2025
50ee38b
Merge pull request #958 from TransformerLensOrg/dev
bryce13950 Jul 9, 2025
f7e76b6
Create hubert_block.py
david-wei-01001 Nov 5, 2025
926c1c4
Delete transformer_lens/components/hubert_block.py
david-wei-01001 Nov 5, 2025
95d50bc
Create HookedAudioEncoder.py
david-wei-01001 Nov 5, 2025
9a29295
Update HookedAudioEncoder.py
david-wei-01001 Nov 5, 2025
48c6efe
Update HookedAudioEncoder.py
david-wei-01001 Nov 5, 2025
1b7559c
Update HookedAudioEncoder.py
david-wei-01001 Nov 5, 2025
94fa33e
Update HookedAudioEncoder.py
david-wei-01001 Nov 5, 2025
6e93a5b
Update loading_from_pretrained.py
david-wei-01001 Nov 5, 2025
4edde8d
Update HookedAudioEncoder.py
david-wei-01001 Nov 5, 2025
a5ef321
Update loading_from_pretrained.py
david-wei-01001 Nov 5, 2025
cd930f3
Update loading_from_pretrained.py
david-wei-01001 Nov 5, 2025
4621730
Update loading_from_pretrained.py
david-wei-01001 Nov 5, 2025
548e693
Create hubert.py
david-wei-01001 Nov 5, 2025
5dc88a1
Update HookedAudioEncoder.py
david-wei-01001 Nov 5, 2025
8282805
Update HookedAudioEncoder.py
david-wei-01001 Nov 6, 2025
7f0c373
Create hubert_test.py
david-wei-01001 Nov 6, 2025
e8bbf84
Update hubert_test.py
david-wei-01001 Nov 6, 2025
86ac1d9
Update HookedAudioEncoder.py
david-wei-01001 Nov 6, 2025
8f1b889
Create hubert_ctc_test.py
david-wei-01001 Nov 6, 2025
afc2a35
Update HookedAudioEncoder.py
david-wei-01001 Nov 6, 2025
f94fa40
Create hubert_hook_test.py
david-wei-01001 Nov 6, 2025
cff50b3
Update hubert_hook_test.py
david-wei-01001 Nov 6, 2025
764810a
done
david-wei-01001 Nov 7, 2025
7e844a3
done
david-wei-01001 Nov 7, 2025
9a6bc7a
done
david-wei-01001 Nov 7, 2025
1ddbf7f
done
david-wei-01001 Nov 7, 2025
c646ee5
done
david-wei-01001 Nov 7, 2025
7d5fe2a
Rename hubert_ctc_test.py to demos/HuBERT_test/hubert_ctc_test.py
david-wei-01001 Nov 7, 2025
21a0256
Rename hubert_hook_test.py to demos/HuBERT_test /hubert_hook_test.py
david-wei-01001 Nov 7, 2025
c9f7c68
Rename hubert_hook_test.py to hubert_hook_test.py
david-wei-01001 Nov 7, 2025
2f578ce
Rename hubert_test.py to demos/HuBERT_test/hubert_test.py
david-wei-01001 Nov 7, 2025
f76c2ee
Update HookedAudioEncoder.py
david-wei-01001 Nov 7, 2025
69345b1
Update HookedAudioEncoder.py
david-wei-01001 Nov 7, 2025
7be3d4e
Update hubert.py
david-wei-01001 Nov 7, 2025
7e177c4
Update hubert_ctc_test.py
david-wei-01001 Nov 7, 2025
6737ccd
Update hubert_hook_test.py
david-wei-01001 Nov 7, 2025
e062f38
Update hubert_hook_test.py
david-wei-01001 Nov 7, 2025
340260f
Update HookedAudioEncoder.py
david-wei-01001 Nov 7, 2025
3c44076
Update loading_from_pretrained.py
david-wei-01001 Nov 7, 2025
64aeb4c
Update HookedAudioEncoder.py
david-wei-01001 Nov 7, 2025
71a4f51
Update hubert.py
david-wei-01001 Nov 7, 2025
f0207ca
Update hubert_ctc_test.py
david-wei-01001 Nov 7, 2025
98f6eac
Update hubert_hook_test.py
david-wei-01001 Nov 7, 2025
da84180
Update hubert_hook_test.py
david-wei-01001 Nov 7, 2025
ede04f8
Update hubert_test.py
david-wei-01001 Nov 7, 2025
305509a
Update loading_from_pretrained.py
david-wei-01001 Nov 7, 2025
6461e2e
Update hubert.py
david-wei-01001 Nov 7, 2025
5344612
Update HookedAudioEncoder.py
david-wei-01001 Nov 7, 2025
32db5d2
Update HookedAudioEncoder.py
david-wei-01001 Nov 7, 2025
560ffb9
Update hubert_hook_test.py
david-wei-01001 Nov 7, 2025
dda10e5
Update HookedAudioEncoder.py
david-wei-01001 Nov 7, 2025
219defb
Update hubert_hook_test.py
david-wei-01001 Nov 7, 2025
2df2d27
Update HookedAudioEncoder.py
david-wei-01001 Nov 7, 2025
46c3344
Update HookedAudioEncoder.py
david-wei-01001 Nov 7, 2025
6272b9f
Update loading_from_pretrained.py
david-wei-01001 Nov 7, 2025
af6163d
Update loading_from_pretrained.py
david-wei-01001 Nov 7, 2025
0b5a860
Update HookedAudioEncoder.py
david-wei-01001 Nov 7, 2025
817c97f
Update HookedAudioEncoder.py
david-wei-01001 Nov 7, 2025
48920e1
Update hubert_hook_test.py
david-wei-01001 Nov 7, 2025
6dcffb2
Update hubert_hook_test.py
david-wei-01001 Nov 7, 2025
5f7af85
Update HookedAudioEncoder.py
david-wei-01001 Nov 7, 2025
fefcea2
Update HookedAudioEncoder.py
david-wei-01001 Nov 9, 2025
b1414e0
Update HookedAudioEncoder.py
david-wei-01001 Nov 9, 2025
94bd3d7
Update HookedAudioEncoder.py
david-wei-01001 Nov 9, 2025
fbae9c1
Update HookedAudioEncoder.py
david-wei-01001 Nov 11, 2025
f23d0d9
Update HookedAudioEncoder.py
david-wei-01001 Nov 11, 2025
d20ee07
Update HookedAudioEncoder.py
david-wei-01001 Nov 12, 2025
00c12cb
Update HookedAudioEncoder.py
david-wei-01001 Nov 12, 2025
14ab5bb
Update HookedAudioEncoder.py
david-wei-01001 Nov 17, 2025
41402ba
Update loading_from_pretrained.py
david-wei-01001 Nov 17, 2025
b5cb2e1
Update loading_from_pretrained.py
david-wei-01001 Nov 17, 2025
f8200bc
Update loading_from_pretrained.py
david-wei-01001 Nov 17, 2025
6926e2b
Update loading_from_pretrained.py
david-wei-01001 Nov 17, 2025
e8e958c
Update loading_from_pretrained.py
david-wei-01001 Nov 17, 2025
fa89321
Update requirements.txt
david-wei-01001 Nov 17, 2025
5a7c5c7
Update requirements.txt
david-wei-01001 Nov 17, 2025
cd8e922
Update loading_from_pretrained.py
david-wei-01001 Nov 18, 2025
77285ba
Update loading_from_pretrained.py
david-wei-01001 Nov 22, 2025
9fa6464
Update loading_from_pretrained.py
david-wei-01001 Nov 24, 2025
fc9327e
Update HookedAudioEncoder.py
david-wei-01001 Nov 24, 2025
c6a43a7
Update bert_pooler.py
david-wei-01001 Nov 24, 2025
9427068
Update bert_pooler.py
david-wei-01001 Nov 24, 2025
232 changes: 232 additions & 0 deletions demos/HuBERT_test/hubert_ctc_test.py
@@ -0,0 +1,232 @@
# test_hubert_ctc_lmhead.py
"""
Test script to verify HookedAudioEncoder.forward(..., use_ctc=True)
loads/uses an lm_head and produces CTC logits.

Usage:
python test_hubert_ctc_lmhead.py
Change the import to point at your HookedAudioEncoder implementation.
"""

import math

import numpy as np
import torch

from transformer_lens import HookedAudioEncoder

# ----- CONFIG -----
SAMPLE_RATE = 16000
DURATION_S = 1.0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 1
# If you want to attempt optional decoding with a HF tokenizer,
# set TOKENIZER_NAME to a valid tokenizer (e.g. "facebook/wav2vec2-base-960h")
# or set to None to skip tokenizer decoding.
TOKENIZER_NAME = "facebook/hubert-large-ls960-ft"
# ------------------

def make_sine(frequency=440.0, sr=SAMPLE_RATE, duration=DURATION_S, amplitude=0.1):
t = np.linspace(0, duration, int(sr*duration), endpoint=False, dtype=np.float32)
return amplitude * np.sin(2 * math.pi * frequency * t)

def has_lm_head(model):
return any(name.endswith("lm_head") or name == "lm_head" for name, _ in model.named_children()) or hasattr(model, "lm_head")

def try_get_lm_head(model):
if hasattr(model, "lm_head"):
return model.lm_head
# try common nested names
for name, module in model.named_modules():
if name.endswith("lm_head") or name == "lm_head":
return module
return None

def print_param_info(module, prefix=""):
if module is None:
print(prefix + "None")
return
params = list(module.parameters())
print(prefix + f"module type: {type(module)}, #params: {sum(p.numel() for p in params)}")
# print weight shape if available
if hasattr(module, "weight"):
try:
print(prefix + f" weight.shape = {tuple(module.weight.shape)}")
except Exception:
pass

if __name__ == "__main__":
model = HookedAudioEncoder.from_pretrained("facebook/hubert-large-ls960-ft")

model.to(DEVICE)

# sample waveform
wav = make_sine(frequency=440.0)
x = torch.from_numpy(wav).unsqueeze(0).to(DEVICE) # shape (1, T)

print("=== lm_head presence BEFORE forward() ===")
print("has_lm_head():", has_lm_head(model))
print("try_get_lm_head():")
print_param_info(try_get_lm_head(model), prefix=" ")

# Forward pass with use_ctc=True (some model APIs accept it directly, some do not).
print("\nCalling forward(..., use_ctc=True) -- if that fails, will set attribute and call without arg")
logits = None
forward_exc = None
try:
# try direct call with argument
out = model(x, use_ctc=True)
except TypeError as e:
# forward signature may not accept use_ctc param; try setting attribute on model and call
forward_exc = e
print("Direct forward(..., use_ctc=True) failed with TypeError - will try setting model.use_ctc = True and calling forward(x).")
try:
model.use_ctc = True
out = model(x)
except Exception as e2:
print("Forward still failed after setting model.use_ctc =", e2)
raise

# Normalize out to logits tensor if possible
def extract_logits(out):
if out is None:
return None
if isinstance(out, torch.Tensor):
return out # assume logits
# dict-like outputs: look for common keys
if isinstance(out, dict):
for key in ("logits", "ctc_logits", "predictions", "hidden_states"):
if key in out:
t = out[key]
# if hidden_states is (batch, seq, dim) that's also fine to inspect
if isinstance(t, torch.Tensor):
return t
# if no known keys found, try to pick first tensor value
for v in out.values():
if isinstance(v, torch.Tensor):
return v
# nothing usable found; give up
return None

logits = extract_logits(out)
print("\n=== Post-forward lm_head presence ===")
print("has_lm_head():", has_lm_head(model))
lm = try_get_lm_head(model)
print("try_get_lm_head():")
print_param_info(lm, prefix=" ")

if logits is None:
print("\nCould not automatically extract logits from the model output. The model returned:", type(out))
# if out is tensor-like but not torch tensor, attempt conversion
if hasattr(out, "numpy"):
try:
logits = torch.from_numpy(out.numpy()).to(DEVICE)
except Exception:
pass

if logits is not None:
print("\n=== Logits / CTC output info ===")
print("logits type:", type(logits))
print("logits shape:", tuple(logits.shape))
# typical CTC logits shape: (batch, time, vocab_size) or (batch, seq_len, vocab)
try:
print("stats: min=%.6g max=%.6g mean=%.6g" % (logits.min().item(), logits.max().item(), logits.mean().item()))
except Exception:
pass
assert torch.isfinite(logits).all(), "Found NaNs/Infs in logits!"

# simple decode: argmax over last dim -> token ids
if logits.ndim >= 2:
token_dim = -1
token_ids = logits.argmax(dim=token_dim) # shape: (batch, time)
token_ids_cpu = token_ids.detach().cpu().numpy()
print("Sample argmax token ids (first batch, up to first 40 frames):")
print(token_ids_cpu[0][:40].tolist())

# Optional: try to decode token ids to text if a tokenizer is available
if TOKENIZER_NAME is not None:
try:
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
# Greedy CTC decode: collapse repeated token ids, then drop the blank token.
# The blank id is assumed to be the tokenizer's pad_token_id (often 0 for CTC vocabularies).
blank_id = getattr(tok, "pad_token_id", None)
seq = token_ids_cpu[0].tolist()
# collapse repeats and remove blanks
collapsed = []
prev = None
for t in seq:
if t == prev:
prev = t
continue
prev = t
if blank_id is not None and t == blank_id:
continue
collapsed.append(t)
decoded = tok.decode(collapsed, skip_special_tokens=True)
print("Decoded (naive collapse) text:", decoded)
except Exception as e:
print("Optional decoding failed:", e)

else:
print("No logits found — cannot run CTC-specific checks.")

# Gradient test specifically for transformer encoder (since lm_head is frozen)
print("\nRunning gradient propagation test through transformer encoder...")

model.train()
for p in model.parameters():
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()

try:
out2 = model(x, use_ctc=True)
except TypeError:
model.use_ctc = True
out2 = model(x)

logits2 = extract_logits(out2)
if logits2 is None:
print("Could not extract logits for gradient test; aborting gradient check.")
else:
loss = logits2.mean()
loss.backward()

# --- Check that lm_head is frozen ---
lm = try_get_lm_head(model)
if lm is not None:
lm_params = list(lm.parameters())
grads = [p.grad for p in lm_params if p.grad is not None]
if len(grads) > 0:
print("Warning: lm_head has gradients, but it should be frozen (eval mode).")
else:
print("✅ lm_head correctly frozen (no gradients).")

# --- Check that transformer block parameters have gradients ---
has_transformer_grad = False
for name, p in model.named_parameters():
if "transformer" in name or "encoder" in name or "block" in name:
print(name)
if p.grad is not None and torch.isfinite(p.grad).all():
has_transformer_grad = True
break

if has_transformer_grad:
print("✅ Gradient test PASSED: transformer block parameters have finite gradients.")
else:
print("❌ Gradient test FAILED: no gradients found in transformer blocks.")


print("\n=== DONE ===")
print("Interpretation notes:")
print(" - If lm_head appears AFTER calling forward(use_ctc=True) and logits shape looks like (B, T, V),")
print(" then your forward-path is constructing/attaching an lm_head and producing CTC logits.")
print(" - If lm_head parameters have finite gradients after loss.backward(), the head is hooked into the graph.")
print(" - If you want a numeric golden-check, instantiate a HF Hubert/Wav2Vec2 CTC model and compare pooled logits/ids (optional).")
print(model.named_parameters())
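# ---------------------------------------------------------------------------
# Optional golden-check sketch (not run by default). This is only a hedged
# illustration of the "numeric golden-check" mentioned in the notes above: it
# assumes the `transformers` reference classes AutoProcessor and HubertForCTC
# accept the same checkpoint, and that HookedAudioEncoder's CTC logits are
# intended to match that reference shape-for-shape. Adapt names as needed.
RUN_HF_GOLDEN_CHECK = False
if RUN_HF_GOLDEN_CHECK and logits is not None:
    from transformers import AutoProcessor, HubertForCTC

    processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
    hf_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(DEVICE).eval()

    # Re-featurize the same waveform the way the reference model expects it.
    inputs = processor(wav, sampling_rate=SAMPLE_RATE, return_tensors="pt")
    with torch.no_grad():
        hf_logits = hf_model(inputs.input_values.to(DEVICE)).logits  # (batch, time, vocab)

    print("HF reference logits shape:", tuple(hf_logits.shape))
    if hf_logits.shape == logits.shape:
        max_diff = (hf_logits - logits.to(hf_logits.dtype)).abs().max().item()
        agreement = (hf_logits.argmax(-1) == logits.argmax(-1)).float().mean().item()
        print(f"max |logit diff| vs HF: {max_diff:.4g}, argmax agreement: {agreement:.2%}")
    else:
        print("Shapes differ; compare argmax token ids or pooled statistics instead.")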