diff --git a/demos/HuBERT_test/hubert_ctc_test.py b/demos/HuBERT_test/hubert_ctc_test.py new file mode 100644 index 000000000..d5ba8f46c --- /dev/null +++ b/demos/HuBERT_test/hubert_ctc_test.py @@ -0,0 +1,232 @@ +# test_hubert_ctc_lmhead.py +""" +Test script to verify HookedAudioEncoder.forward(..., use_ctc=True) +loads/uses an lm_head and produces CTC logits. + +Usage: + python test_hubert_ctc_lmhead.py +Change the import to point at your HookedAudioEncoder implementation. +""" + +import math + +import numpy as np +import torch + +from transformer_lens import HookedAudioEncoder + +# ----- CONFIG ----- +SAMPLE_RATE = 16000 +DURATION_S = 1.0 +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +BATCH_SIZE = 1 +# If you want to attempt optional decoding with a HF tokenizer, +# set TOKENIZER_NAME to a valid tokenizer (e.g. "facebook/wav2vec2-base-960h") +# or set to None to skip tokenizer decoding. +TOKENIZER_NAME = "facebook/hubert-large-ls960-ft" +# ------------------ + +def make_sine(frequency=440.0, sr=SAMPLE_RATE, duration=DURATION_S, amplitude=0.1): + t = np.linspace(0, duration, int(sr*duration), endpoint=False, dtype=np.float32) + return amplitude * np.sin(2 * math.pi * frequency * t) + +def has_lm_head(model): + return any(name.endswith("lm_head") or name == "lm_head" for name, _ in model.named_children()) or hasattr(model, "lm_head") + +def try_get_lm_head(model): + if hasattr(model, "lm_head"): + return model.lm_head + # try common nested names + for name, module in model.named_modules(): + if name.endswith("lm_head") or name == "lm_head": + return module + return None + +def print_param_info(module, prefix=""): + if module is None: + print(prefix + "None") + return + params = list(module.parameters()) + print(prefix + f"module type: {type(module)}, #params: {sum(p.numel() for p in params)}") + # print weight shape if available + if hasattr(module, "weight"): + try: + print(prefix + f" weight.shape = {tuple(module.weight.shape)}") + except Exception: + pass + +if __name__ == "__main__": + model = HookedAudioEncoder.from_pretrained("facebook/hubert-large-ls960-ft") + + model.to(DEVICE) + + # sample waveform + wav = make_sine(frequency=440.0) + x = torch.from_numpy(wav).unsqueeze(0).to(DEVICE) # shape (1, T) + + print("=== lm_head presence BEFORE forward() ===") + print("has_lm_head():", has_lm_head(model)) + print("try_get_lm_head():") + print_param_info(try_get_lm_head(model), prefix=" ") + + # Forward pass with use_ctc=True (some model APIs accept it directly, some do not). 
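+    # Rough sanity target for the CTC time dimension (hedged assumption): the standard
+    # HuBERT/wav2vec2 conv frontend downsamples by a factor of ~320 (strides 5*2*2*2*2*2*2),
+    # so a 1 s clip at 16 kHz should give roughly 16000 // 320 = 50 frames (the exact conv
+    # arithmetic gives 49). Compare this against the time dimension of the logits below.
+    approx_expected_frames = x.shape[-1] // 320
+    print("Approx. expected number of CTC frames:", approx_expected_frames)
+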
+ print("\nCalling forward(..., use_ctc=True) -- if that fails, will set attribute and call without arg") + logits = None + forward_exc = None + try: + # try direct call with argument + out = model(x, use_ctc=True) + except TypeError as e: + # forward signature may not accept use_ctc param; try setting attribute on model and call + forward_exc = e + print("Direct forward(..., use_ctc=True) failed with TypeError - will try setting model.use_ctc = True and calling forward(x).") + try: + if hasattr(model, "use_ctc"): + model.use_ctc = True + else: + # set attribute anyway + setattr(model, "use_ctc", True) + out = model(x) + except Exception as e2: + print("Forward still failed after setting model.use_ctc =", e2) + raise + + # Normalize out to logits tensor if possible + def extract_logits(out): + if out is None: + return None + if isinstance(out, torch.Tensor): + return out # assume logits + # dict-like outputs: look for common keys + if isinstance(out, dict): + for key in ("logits", "ctc_logits", "predictions", "hidden_states"): + if key in out: + t = out[key] + # if hidden_states is (batch, seq, dim) that's also fine to inspect + if isinstance(t, torch.Tensor): + return t + # if no known keys found, try to pick first tensor value + for v in out.values(): + if isinstance(v, torch.Tensor): + return v + # fallback: try to convert + return None + + logits = extract_logits(out) + print("\n=== Post-forward lm_head presence ===") + print("has_lm_head():", has_lm_head(model)) + lm = try_get_lm_head(model) + print("try_get_lm_head():") + print_param_info(lm, prefix=" ") + + if logits is None: + print("\nCould not automatically extract logits from the model output. The model returned:", type(out)) + # if out is tensor-like but not torch tensor, attempt conversion + if hasattr(out, "numpy"): + try: + logits = torch.from_numpy(out.numpy()).to(DEVICE) + except Exception: + pass + + if logits is not None: + print("\n=== Logits / CTC output info ===") + print("logits type:", type(logits)) + print("logits shape:", tuple(logits.shape)) + # typical CTC logits shape: (batch, time, vocab_size) or (batch, seq_len, vocab) + try: + print("stats: min=%.6g max=%.6g mean=%.6g" % (logits.min().item(), logits.max().item(), logits.mean().item())) + except Exception: + pass + assert torch.isfinite(logits).all(), "Found NaNs/Infs in logits!" 
+ + # simple decode: argmax over last dim -> token ids + if logits.ndim >= 2: + token_dim = -1 + token_ids = logits.argmax(dim=token_dim) # shape: (batch, time) + token_ids_cpu = token_ids.detach().cpu().numpy() + print("Sample argmax token ids (first batch, up to first 40 frames):") + print(token_ids_cpu[0][:40].tolist()) + + # Optional: try to decode token ids to text if a tokenizer is available + if TOKENIZER_NAME is not None: + try: + from transformers import AutoTokenizer + tok = AutoTokenizer.from_pretrained(TOKENIZER_NAME) + # For many CTC tokenizers, you need to collapse repeats and remove blank token id (often id=0 or tok.pad_token_id) + # Here we do a naive collapse+remove assuming blank token is tokenizer.pad_token_id or tokenizer.pad_token_id==tok.pad_token_id + blank_id = getattr(tok, "pad_token_id", None) + seq = token_ids_cpu[0].tolist() + # collapse repeats and remove blanks + collapsed = [] + prev = None + for t in seq: + if t == prev: + prev = t + continue + prev = t + if blank_id is not None and t == blank_id: + continue + collapsed.append(t) + decoded = tok.decode(collapsed, skip_special_tokens=True) + print("Decoded (naive collapse) text:", decoded) + except Exception as e: + print("Optional decoding failed:", e) + + else: + print("No logits found — cannot run CTC-specific checks.") + + # Gradient test specifically for transformer encoder (since lm_head is frozen) + print("\nRunning gradient propagation test through transformer encoder...") + + model.train() + for p in model.parameters(): + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + + try: + out2 = model(x, use_ctc=True) + except TypeError: + if hasattr(model, "use_ctc"): + model.use_ctc = True + out2 = model(x) + + logits2 = extract_logits(out2) + if logits2 is None: + print("Could not extract logits for gradient test; aborting gradient check.") + else: + loss = logits2.mean() + loss.backward() + + # --- Check that lm_head is frozen --- + lm = try_get_lm_head(model) + if lm is not None: + lm_params = list(lm.parameters()) + grads = [p.grad for p in lm_params if p.grad is not None] + if len(grads) > 0: + print("Warning: lm_head has gradients, but it should be frozen (eval mode).") + else: + print("✅ lm_head correctly frozen (no gradients).") + + # --- Check that transformer block parameters have gradients --- + has_transformer_grad = False + for name, p in model.named_parameters(): + if "transformer" in name or "encoder" in name or "block" in name: + print(name) + if p.grad is not None and torch.isfinite(p.grad).all(): + has_transformer_grad = True + break + + if has_transformer_grad: + print("✅ Gradient test PASSED: transformer block parameters have finite gradients.") + else: + print("❌ Gradient test FAILED: no gradients found in transformer blocks.") + + + print("\n=== DONE ===") + print("Interpretation notes:") + print(" - If lm_head appears AFTER calling forward(use_ctc=True) and logits shape looks like (B, T, V),") + print(" then your forward-path is constructing/attaching an lm_head and producing CTC logits.") + print(" - If lm_head parameters have finite gradients after loss.backward(), the head is hooked into the graph.") + print(" - If you want a numeric golden-check, instantiate a HF Hubert/Wav2Vec2 CTC model and compare pooled logits/ids (optional).") + print(model.named_parameters()) diff --git a/demos/HuBERT_test/hubert_hook_test.py b/demos/HuBERT_test/hubert_hook_test.py new file mode 100644 index 000000000..79225bdc0 --- /dev/null +++ b/demos/HuBERT_test/hubert_hook_test.py @@ -0,0 
+1,180 @@ +import math + +import numpy as np +import torch + +import transformer_lens.utils as utils +from transformer_lens import HookedAudioEncoder + +# ---- Simple sine audio generator ---- +SAMPLE_RATE = 16000 +DURATION_S = 1.0 +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + +def make_sine(sr=SAMPLE_RATE, duration=DURATION_S, freq=440.0, amp=0.1): + t = np.linspace(0, duration, int(sr*duration), endpoint=False, dtype=np.float32) + return amp * np.sin(2 * math.pi * freq * t) + +audio_model = HookedAudioEncoder.from_pretrained("facebook/hubert-base-ls960", device="cuda") + +def main(): + # --- Build a 1s test waveform --- + wav = make_sine() + # If to_frames expects numpy or torch, both are accepted by your implementation + raw_batch = [wav] # batch of one + + # --- Convert to frames using your helper (you provided to_frames) --- + # IMPORTANT: use the same sampling_rate you used during training/FT (16k typical) + try: + frames, frame_mask = audio_model.to_frames(raw_batch, sampling_rate=SAMPLE_RATE, move_to_device=True) + except NameError: + raise RuntimeError("Replace `audio_model` with your model/wrapper instance that implements to_frames().") + + # frames shape expected: (batch, frames, hidden) ; frame_mask: (batch, frames) (1/0) + print("frames.shape:", tuple(frames.shape)) + if frame_mask is not None: + print("frame_mask.shape:", tuple(frame_mask.shape)) + + # --- Run with cache to inspect attention pattern --- + # remove_batch_dim=True makes cached activations shaped like (pos, ...) for easier visualization (like LLaMA example) + logits, cache = audio_model.run_with_cache(frames, one_zero_attention_mask=frame_mask, remove_batch_dim=True) + + # Picking a layer and head for visualization + layer_to_visualize = 0 + # act name for attention pattern — this is the same helper you used earlier + pattern_name = utils.get_act_name("pattern", layer_to_visualize) # e.g. "pattern_0" depending on utils + # some implementations store pattern as (layer, "attn") tuple; utils.get_act_name helps avoid mistakes + + # Extract attention pattern. Adapt this extraction if your cache key structure differs: + try: + attention_pattern = cache[pattern_name] # expected shape: (pos, pos, n_heads) or (pos, n_heads, pos) depending on implementation + except Exception: + # fallback: try tuple-key style + try: + attention_pattern = cache["pattern", layer_to_visualize, "attn"] + except Exception as exc: + raise RuntimeError(f"Couldn't find attention pattern in cache. Keys: {list(cache.keys())}") from exc + + # Build human-friendly "tokens" for frames (e.g. frame indices as strings) + n_frames = attention_pattern.shape[0] + frame_tokens = [f"f{i}" for i in range(n_frames)] + + print("Layer", layer_to_visualize, "attention pattern shape:", tuple(attention_pattern.shape)) + print("Displaying attention patterns (layer", layer_to_visualize, ")") + # display(cv.attention.attention_patterns(tokens=frame_tokens, attention=attention_pattern)) + + # --- Define a head ablation hook (zero out a given head's v output) --- + head_index_to_ablate = 0 + layer_to_ablate = 0 + + # Hook target: v (value output) or "pattern" depending on what you'd like to ablate. + # Using the 'v' activation is a common choice, same form as your LLaMA example. + v_act_name = utils.get_act_name("v", layer_to_ablate) + + def head_ablation_hook(value, hook): + """ + value expected shape: [batch pos head d_head] OR [pos head d_head] when remove_batch_dim=True + We'll allow both shapes. 
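+        Returning a tensor from a TransformerLens hook replaces the activation for the
+        rest of the forward pass, so zeroing one head's value vectors here ablates that head.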
+ """ + # convert to mutable clone (some frameworks give non-writable tensors) + v = value.clone() + if v.ndim == 4: + # (B, pos, heads, d) + v[:, :, head_index_to_ablate, :] = 0.0 + elif v.ndim == 3: + # (pos, heads, d) + v[:, head_index_to_ablate, :] = 0.0 + else: + raise RuntimeError(f"Unexpected v tensor ndim={v.ndim}") + return v + + # --- Compute a downstream quantity without ablation --- + # Choose a metric you care about. Good choices: + # - CTC logits (if using use_ctc=True) -> argmax tokens or loss + # - Pooled encoder representation (mean of final resid_post) -> cosine similarity + # We'll implement both: try to extract CTC logits from model output; if not found, use pooled resid_post. + + def run_and_get_repr(frames, frame_mask, hooks=None): + # hooks: list of (act_name, hook_fn) tuples for run_with_hooks + if hooks is None: + # run_with_cache to gather activations + cache = audio_model.run_with_cache(frames, one_zero_attention_mask=frame_mask, remove_batch_dim=True) + out = audio_model.run_with_hooks(frames, fwd_hooks=[]) + # NOTE: if your API returns outputs directly from run_with_cache, adapt as needed. + else: + # run with hooks and also capture cache + # run_with_hooks typically returns output (or logits) and optionally a cache depending on your implementation + out = audio_model.run_with_hooks(frames, fwd_hooks=hooks, one_zero_attention_mask=frame_mask) + # If return_type="both" isn't supported, you can run run_with_cache and run_with_hooks separately. + # Try to extract CTC logits from `out` first + logits = None + if isinstance(out, dict): + for k in ("logits", "ctc_logits", "logits_ctc", "predictions"): + if k in out and isinstance(out[k], torch.Tensor): + logits = out[k] + break + elif isinstance(out, torch.Tensor): + # ambiguous: could be embeddings or logits + logits = out + + # if logits exist -> pooled logits (mean over time) as representation + if logits is not None: + # ensure shape (batch, time, vocab) -> pool over time axis (1) + if logits.ndim == 3: + pooled = logits.mean(dim=1) # (batch, vocab) + elif logits.ndim == 2: + pooled = logits # maybe (batch, vocab) + else: + pooled = logits.view(logits.shape[0], -1).mean(dim=1, keepdim=True) + return pooled, logits, None # third slot reserved for cache + + # fallback: use final residual activation from cache (resid_post of last layer) + try: + last_layer = audio_model.cfg.n_layers - 1 + resid_name = utils.get_act_name("resid_post", last_layer) + # get cache from run_with_cache (we ran above) + cache = audio_model.run_with_cache(frames, one_zero_attention_mask=frame_mask, remove_batch_dim=True) + resid = cache[resid_name] # e.g. 
(pos, d) or (batch,pos,d) + # mean-pool across pos dimension + if resid.ndim == 3: + pooled = resid.mean(dim=1) # (batch, d) + elif resid.ndim == 2: + pooled = resid.mean(dim=0, keepdim=True) + else: + raise RuntimeError("Unexpected resid_post shape") + return pooled, None, cache + except Exception as e: + raise RuntimeError("Couldn't extract logits or resid_post; adapt the extraction to your model's output format.") from e + + # Get baseline representation + baseline_repr, baseline_logits, baseline_cache = run_and_get_repr(frames, frame_mask, hooks=None) + print("Baseline representation shape:", tuple(baseline_repr.shape)) + + # --- Run with ablation hook and get representation --- + hooks = [(v_act_name, head_ablation_hook)] + ablated_repr, ablated_logits, ablated_cache = run_and_get_repr(frames, frame_mask, hooks=hooks) + print("Ablated representation shape:", tuple(ablated_repr.shape)) + + # --- Compare representations (cosine similarity) --- + cos = torch.nn.functional.cosine_similarity(baseline_repr, ablated_repr, dim=-1) + print("Cosine similarity baseline vs ablated:", cos.detach().cpu().numpy()) + + # If you have logits, you can also compare token sequences (argmax) or loss increase + if baseline_logits is not None and ablated_logits is not None: + b_ids = baseline_logits.argmax(dim=-1) # (batch, time) + a_ids = ablated_logits.argmax(dim=-1) + print("Sample argmax token ids (baseline):", b_ids[0][:40].cpu().numpy().tolist()) + print("Sample argmax token ids (ablated): ", a_ids[0][:40].cpu().numpy().tolist()) + + print("Done. Interpret the results:") + print(" - A large drop in cosine similarity (or large change in argmax tokens / increase in loss) means the ablated head mattered.") + print(" - If ablation causes little change, that head may be redundant or not used for this example.") + +if __name__ == "__main__": + # create/instantiate your model here: replace the placeholder below + # Example: + # audio_model = HookedAudioEncoder.from_pretrained("...").to(DEVICE) + # audio_model.cfg.device = DEVICE + # For wrapper that exposes to_frames: + # audio_model = YourWrapperClass(...) 
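+    # Note: in this script `audio_model` is already instantiated at module level above via
+    # HookedAudioEncoder.from_pretrained(...), so the placeholder comments here only matter
+    # if you want to substitute your own wrapper.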
+ main() diff --git a/demos/HuBERT_test/hubert_test.py b/demos/HuBERT_test/hubert_test.py new file mode 100644 index 000000000..55bd8f5cc --- /dev/null +++ b/demos/HuBERT_test/hubert_test.py @@ -0,0 +1,132 @@ +# test_hubert_hooked.py +import math + +import numpy as np +import torch + +from transformer_lens import HookedAudioEncoder + +# ---------- CONFIG ---------- +SAMPLE_RATE = 16000 +DURATION_S = 1.0 +BATCH_SIZE = 1 +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +# Name of HF checkpoint to use if you want to compare outputs (optional) +HF_CHECKPOINT = "facebook/hubert-base-ls960" # optional +# ---------------------------- + +def make_sine(frequency=440.0, sr=SAMPLE_RATE, duration=DURATION_S, amplitude=0.1): + t = np.linspace(0, duration, int(sr*duration), endpoint=False, dtype=np.float32) + wav = amplitude * np.sin(2 * math.pi * frequency * t) + return wav + +def run_basic_sanity_tests(model, waveform_np): + """Run quick checks: forward pass, shape, finite, deterministic, grad flow.""" + model.to(DEVICE) + + # Prepare tensor: shape (batch, time) + x = torch.from_numpy(waveform_np).unsqueeze(0).to(DEVICE) # (1, T) + + # 1) Eval forward: no grad + model.eval() + with torch.no_grad(): + out1 = model(x) # adapt if your API uses return_type="predictions" or similar + print("Forward (eval) output type:", type(out1)) + try: + out_tensor = out1 if isinstance(out1, torch.Tensor) else out1["predictions"] + except Exception: + out_tensor = out1 # fallback + + print("Output shape:", tuple(out_tensor.shape)) + print("Output stats: min=%.6g max=%.6g mean=%.6g" % (out_tensor.min().item(), out_tensor.max().item(), out_tensor.mean().item())) + assert torch.isfinite(out_tensor).all(), "Found NaNs or Infs in forward output!" + + # 2) Determinism in eval + with torch.no_grad(): + out2 = model(x) + # if model returns dict-like, extract tensor again + out2_tensor = out2 if isinstance(out2, torch.Tensor) else out2["predictions"] + if not torch.allclose(out_tensor, out2_tensor, atol=1e-6): + print("Warning: outputs differ between two eval runs (non-deterministic?), max diff:", (out_tensor - out2_tensor).abs().max().item()) + else: + print("Determinism test passed (eval mode).") + + # 3) Gradient flow test in train mode + model.train() + # zero grads + for p in model.parameters(): + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + out_train = model(x) + out_train_tensor = out_train if isinstance(out_train, torch.Tensor) else out_train["predictions"] + + # small scalar loss + loss = out_train_tensor.mean() + loss.backward() + # check some parameters got gradients + grads_found = any((p.grad is not None and torch.isfinite(p.grad).all()) for p in model.parameters() if p.requires_grad) + assert grads_found, "No finite gradients found on any parameter after backward()" + print("Gradient check passed: some parameters have finite gradients.") + +def optional_compare_to_hf(your_model, waveform_np, sr=SAMPLE_RATE): + """ + OPTIONAL: compare your_model outputs to Hugging Face's HubertModel outputs. + This requires transformers to be installed and internet access to download the checkpoint. + Important: to get a meaningful comparison you must match *exact preprocessing* (resampling, + normalization, padding/truncation) that the HF model expects and that your model used. 
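+    A cosine similarity near 1.0 suggests the two models agree; lower values may simply
+    reflect preprocessing mismatches rather than a broken weight conversion.
+
+    Example (illustrative):
+        optional_compare_to_hf(model, wav, sr=SAMPLE_RATE)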
+ """ + try: + from transformers import HubertModel, Wav2Vec2FeatureExtractor + except Exception as e: + print("Transformers or feature extractor not available:", e) + return + + print("Loading Hugging Face HubertModel for optional comparison (may take a while)...") + hf_feat = Wav2Vec2FeatureExtractor(sampling_rate=sr, do_normalize=True) + hf_model = HubertModel.from_pretrained(HF_CHECKPOINT).to(DEVICE).eval() + + # Prepare input for HF model + input_values = hf_feat(waveform_np, sampling_rate=sr, return_tensors="pt").get("input_values") # (1, T) + input_values = input_values.to(DEVICE) + + with torch.no_grad(): + hf_outputs = hf_model(input_values).last_hidden_state # (1, L, D) + # Pool HF tokens to a single vector (simple mean pooling) + hf_embedding = hf_outputs.mean(dim=1) # (1, D) + + # Get your model's representation and pool similarly + your_model.eval() + with torch.no_grad(): + your_out = your_model(torch.from_numpy(waveform_np).unsqueeze(0).to(DEVICE)) + your_tensor = your_out if isinstance(your_out, torch.Tensor) else your_out["predictions"] # shape depends on your model + # If your output has time dimension, mean-pool across time + if your_tensor.ndim == 3: + your_emb = your_tensor.mean(dim=1) + else: + your_emb = your_tensor # assume (1, D) or similar + + # Resize / project if dims differ (simple check) + if hf_embedding.shape[1] != your_emb.shape[1]: + print(f"Dimension mismatch (HF {hf_embedding.shape[1]} vs your {your_emb.shape[1]}). " + "You can compare after projecting to a common dim (not shown).") + return + + # Cosine similarity + cos = torch.nn.functional.cosine_similarity(hf_embedding, your_emb, dim=1) + print("Cosine similarity between HF pooled embedding and your model:", cos.cpu().numpy()) + +if __name__ == "__main__": + # Create sample waveform + wav = make_sine(frequency=440.0, sr=SAMPLE_RATE, duration=DURATION_S) + + # ----------------------- + # Instantiate your model + # ----------------------- + # Example 1: from_pretrained API (if you implemented it) + model = HookedAudioEncoder.from_pretrained("facebook/hubert-base-ls960").to(DEVICE) + # Run tests + run_basic_sanity_tests(model, wav) + + # Optionally compare to HF (network required) + optional_compare_to_hf(model, wav, sr=SAMPLE_RATE) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..f90bac6e9 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +rich +torch==2.9.0 +transformers +datasets +jaxtyping +datasets<3.0.0 +einops +better_abc +typeguard +wandb +circuitsvis diff --git a/transformer_lens/HookedAudioEncoder.py b/transformer_lens/HookedAudioEncoder.py new file mode 100644 index 000000000..5d7f5a0a4 --- /dev/null +++ b/transformer_lens/HookedAudioEncoder.py @@ -0,0 +1,547 @@ +"""Hooked Audio Encoder. + +Contains a HuBERT style model. This is separate from :class:`transformer_lens.HookedTransformer` +because it has a significantly different architecture to e.g. GPT style transformers. 
+""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload + +from einops import repeat +from jaxtyping import Float, Int +import numpy as np +import torch +import torch.nn as nn +from transformers import ( + AutoFeatureExtractor, + AutoProcessor, + HubertForCTC, + HubertModel, + Wav2Vec2Model +) +from typing_extensions import Literal + +from transformer_lens import loading_from_pretrained as loading +from transformer_lens.ActivationCache import ActivationCache +from transformer_lens.components import ( + Attention, + BertBlock, + MLP, +) +from transformer_lens.FactoredMatrix import FactoredMatrix +from transformer_lens.hook_points import HookedRootModule +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig +from transformer_lens.utilities import devices + + +T = TypeVar("T", bound="HookedEncoder") + + +class HookedAudioEncoder(HookedRootModule): + """ + This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule. + + Limitations: + - The model does not include dropouts, which may lead to inconsistent results from training or fine-tuning. + + Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported: + - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model + """ + + def __init__( + self, + cfg: Union[HookedTransformerConfig, Dict], + move_to_device: bool = True, + model_name: str = "facebook/hubert-base-ls960", + use_ctc: bool = False, + **kwargs: Any, + ): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig(**cfg) + elif isinstance(cfg, str): + raise ValueError( + "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoder.from_pretrained() instead." + ) + self.cfg = cfg + + assert self.cfg.n_devices == 1, "Multiple devices not supported for HookedEncoder" + + self.blocks = nn.ModuleList([BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]) + + if model_name.endswith("-ft") and use_ctc: + # fine-tuned model (has CTC head) + use_ctc = True + processor = AutoProcessor.from_pretrained(model_name) # builds input_values + attention_mask + else: + # pretraining-only model (no CTC) + use_ctc = False + processor = AutoFeatureExtractor.from_pretrained(model_name) + + if model_name.endswith("-ft") and use_ctc: + hubert_model = HubertForCTC.from_pretrained(model_name) + elif "wav2vec2" in model_name: + hubert_model = Wav2Vec2Model.from_pretrained(model_name) + else: + hubert_model = HubertModel.from_pretrained(model_name) + + if move_to_device: + if self.cfg.device is None: + raise ValueError("Cannot move to device when device is None") + hubert_model.to(self.cfg.device) + + hubert_model.eval() + self.processor = processor + if use_ctc: + self.hubert_model = hubert_model.hubert + self.lm_head = hubert_model.lm_head + for p in self.lm_head.parameters(): + p.requires_grad = False + else: + self.hubert_model = hubert_model + self.lm_head = None + + if move_to_device: + if self.cfg.device is None: + raise ValueError("Cannot move to device when device is None") + self.to(self.cfg.device) + + self.setup() + + def _ensure_numpy(self, wave): + """ + Convert torch.Tensor / np.ndarray / list -> 1D np.float32 array on CPU. 
+ """ + if isinstance(wave, torch.Tensor): + arr = wave.detach().cpu().numpy() + elif isinstance(wave, np.ndarray): + arr = wave + elif isinstance(wave, list): + arr = np.asarray(wave) + else: + raise TypeError("wave must be torch.Tensor, np.ndarray or list of floats") + + # force 1-D (if stereo or shape (N,1) etc) + if arr.ndim > 1: + # if shape (n_samples, n_channels) average channels -> mono + if arr.shape[1] <= arr.shape[0]: + arr = arr.mean(axis=1) + else: + arr = arr.reshape(-1) + + return arr.astype(np.float32, copy=False) + + + def to_frames( + self, + raw_inputs: Union[torch.Tensor, List[torch.Tensor], List[np.ndarray]], + sampling_rate: int = 16000, + move_to_device: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Convert raw audio batch -> (projected frames, frame_attention_mask) + + Args: + raw_inputs: one of: + - a 1D torch.Tensor or numpy array (single waveform) + - a list of 1D torch.Tensors / numpy arrays (batch) + self.processor: HF AutoProcessor (creates input_values + sample-level attention_mask) + self.model: pretrained HubertModel (provides feature_extractor and feature_projection) + sampling_rate: sample rate of the audio (default 16k) + move_to_device: move outputs to model.device + + Returns: + frames: torch.Tensor of shape (batch, frames, hidden_size) <- after feature_projection + frame_attention_mask: torch.LongTensor of shape (batch, frames) with 1 for real frames, 0 for padding + """ + # AutoFeatureExtractor works better onnumpy array where it pads automatically. If passing in tensors, it does not pad properly, giving inhomogeneous arts error + if isinstance(raw_inputs, (torch.Tensor, np.ndarray)): + waves = [self._ensure_numpy(raw_inputs)] + elif isinstance(raw_inputs, list): + waves = [self._ensure_numpy(w) for w in raw_inputs] + else: + raise TypeError("Unsupported raw_inputs type") + + # Use HF processor to create input_values (padded) + sample-level attention_mask + # Processor will do padding so we can pass a variable-length batch + proc_out = self.processor(waves, sampling_rate=sampling_rate, return_tensors="pt", padding=True, return_attention_mask=True) + input_values = proc_out["input_values"] # (batch, samples), float + sample_attention_mask = proc_out.get("attention_mask") # (batch, samples), 1 for valid, 0 for padding; may be None + + # move to device + device = self.cfg.device + if move_to_device: + input_values = input_values.to(device) + if sample_attention_mask is not None: + sample_attention_mask = sample_attention_mask.to(device) + + # 1) convolutional frontend -> (batch, conv_dim, conv_time) + if input_values.ndim > 2: + input_values = input_values.squeeze() + if input_values.ndim == 1: + input_values = input_values.unsqueeze(0) # (1, T) + with torch.no_grad(): + conv_feats = self.hubert_model.feature_extractor(input_values) # (B, C, T_conv) + + # 2) transpose to (batch, T_conv, C) + extract_features = conv_feats.transpose(1, 2) + + # 3) compute reduced frame-level attention mask (if sample mask provided) + frame_attention_mask = None + if sample_attention_mask is not None: + # model should provide helper _get_feature_vector_attention_mask + try: + frame_attention_mask = self.hubert_model._get_feature_vector_attention_mask(extract_features.shape[1], sample_attention_mask) + except AttributeError: + # fallback: compute output lengths and create mask similarly to HF implementation + # compute output lengths (downsampled lengths) from sample attention mask (sums per example) + input_lengths = sample_attention_mask.sum(dim=-1) # 
(batch,)
+                # compute output lengths through the conv layers using
+                # self.hubert_model._get_feat_extract_output_lengths if it exists
+                if hasattr(self.hubert_model, "_get_feat_extract_output_lengths"):
+                    output_lengths = self.hubert_model._get_feat_extract_output_lengths(input_lengths).to(torch.long)
+                else:
+                    # fallback to naive downsample ratio: output_frames = extract_features.shape[1]
+                    output_lengths = torch.full((sample_attention_mask.shape[0],), extract_features.shape[1], device=device, dtype=torch.long)
+
+                batch_size = sample_attention_mask.shape[0]
+                feat_len = extract_features.shape[1]
+                frame_attention_mask = torch.zeros((batch_size, feat_len), dtype=sample_attention_mask.dtype, device=device)
+                # mark the last valid index for each example, then fill ones before it with a flip/cumsum trick
+                idx = (torch.arange(batch_size, device=device), (output_lengths - 1).clamp(min=0))
+                frame_attention_mask[idx] = 1
+                frame_attention_mask = frame_attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool().long()
+
+        # 4) feature projection -> (batch, frames, hidden_size)
+        with torch.no_grad():
+            hidden_states = self.hubert_model.feature_projection(extract_features)
+            # HF's HuBERT feature_projection returns a plain tensor; Wav2Vec2's returns a
+            # (hidden_states, extract_features) tuple, so unpack if needed.
+            if isinstance(hidden_states, tuple):
+                hidden_states = hidden_states[0]
+
+        # convert bool mask to long (1/0) if needed
+        if frame_attention_mask is not None:
+            frame_attention_mask = frame_attention_mask.to(dtype=torch.long)
+
+        return hidden_states, frame_attention_mask
+
+    def encoder_output(
+        self,
+        frames: torch.Tensor,  # (batch, frames, d_model) <-- precomputed conv features
+        one_zero_attention_mask: Optional[torch.Tensor] = None,  # (batch, frames)
+    ):
+        # Ensure device
+        if frames.device.type != self.cfg.device:
+            frames = frames.to(self.cfg.device)
+        if one_zero_attention_mask is not None:
+            one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device)
+
+        position_embeddings = self.hubert_model.encoder.pos_conv_embed(frames)
+        resid = frames + position_embeddings
+        resid = self.hubert_model.encoder.layer_norm(resid)
+
+        large_negative_number = -torch.inf
+        mask = (
+            repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos")
+            if one_zero_attention_mask is not None
+            else None
+        )
+        additive_attention_mask = (
+            torch.where(mask == 1, large_negative_number, 0) if mask is not None else None
+        )
+        for block in self.blocks:
+            resid = block(resid, additive_attention_mask)
+
+        return resid
+
+    def forward(
+        self,
+        inputs: Union[
+            torch.Tensor,  # waveform (1D) OR precomputed frames (3D)
+            List[Union[torch.Tensor, np.ndarray]],  # list of waveforms
+            Tuple[torch.Tensor, torch.Tensor],  # (frames, frame_mask)
+        ],
+        one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
+        sampling_rate: int = 16000,
+        use_ctc: bool = False,
+        move_to_device: bool = True,
+    ) -> Optional[torch.Tensor]:
+        """
+        HuBERT-like forward (Transformer-Lens style).
+
+        Args:
+            inputs: one of:
+                - 1D torch.Tensor or numpy array (single waveform) OR list of 1D waveforms -> will call self.to_frames(...)
+                - 3D torch.Tensor shaped (batch, frames, d_model) -> treated as precomputed frames (skips to_frames)
+                - tuple (frames, frame_mask) -> used directly
+            one_zero_attention_mask: optional (batch, frames) mask with 1 for valid frames and 0 for padding.
+            sampling_rate: sampling rate used by to_frames when converting raw audio.
+            use_ctc: whether to apply the frozen lm_head from HubertForCTC and return CTC logits
+                (requires a fine-tuned "-ft" checkpoint loaded with use_ctc=True).
+            move_to_device: move tensors to self.cfg.device.
+
+        Returns:
+            torch.Tensor: final encoder hidden states of shape (batch, frames, d_model), or
+            CTC logits of shape (batch, frames, vocab_size) when use_ctc=True and an lm_head is loaded.
+        """
+        # ---------- 1) Normalize input: get (frames, frame_mask) ----------
+        frames = None
+        frame_mask = None  # one_zero_attention_mask: 1 = valid, 0 = padding
+
+        # If user passed a (frames, mask) tuple
+        if isinstance(inputs, tuple) and len(inputs) == 2 and isinstance(inputs[0], torch.Tensor):
+            frames, frame_mask = inputs
+
+        # If user passed a 3D tensor -> assume (B, T, D) frames (pre-projected)
+        elif isinstance(inputs, torch.Tensor) and inputs.ndim == 3:
+            frames = inputs
+            frame_mask = one_zero_attention_mask  # may be None
+
+        # Else treat as raw waveform(s) -> call to_frames
+        else:
+            # allow a single 1D tensor / numpy array or a list of tensors/arrays
+            frames, frame_mask = self.to_frames(
+                inputs, sampling_rate=sampling_rate, move_to_device=move_to_device
+            )
+            # to_frames already places tensors on device if move_to_device=True
+            if isinstance(frames, tuple):
+                frames = frames[0]
+            frame_mask = frame_mask if one_zero_attention_mask is None else one_zero_attention_mask
+
+        # ---------- 2) Ensure device & dtype consistency ----------
+        device = self.cfg.device
+        if frames.device.type != device:
+            frames = frames.to(device)
+        if frame_mask is not None:
+            frame_mask = frame_mask.to(device)
+
+        # ---------- 3) Run encoder (pos_conv_embed / layer_norm handled inside encoder_output) ----------
+        resid = self.encoder_output(frames, frame_mask)  # (B, T, d_model)
+
+        if use_ctc:
+            if self.lm_head is None:
+                logging.warning(
+                    "use_ctc=True was passed but no lm_head is loaded (load a fine-tuned '-ft' "
+                    "checkpoint with use_ctc=True); returning hidden states instead."
+                )
+                return resid
+            if isinstance(resid, tuple):
+                hidden_states = resid[0]  # take the hidden states
+            else:
+                hidden_states = resid  # already a tensor
+            resid = self.lm_head(hidden_states)  # (B, T, vocab_size)
+
+        return resid
+
+    @overload
+    def run_with_cache(
+        self, *model_args: Any, return_cache_object: Literal[True] = True, **kwargs: Any
+    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
+        ...
+
+    @overload
+    def run_with_cache(
+        self, *model_args: Any, return_cache_object: Literal[False], **kwargs: Any
+    ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
+        ...
+
+    def run_with_cache(
+        self,
+        *model_args: Any,
+        return_cache_object: bool = True,
+        remove_batch_dim: bool = False,
+        **kwargs: Any,
+    ) -> Tuple[
+        Float[torch.Tensor, "batch pos d_vocab"],
+        Union[ActivationCache, Dict[str, torch.Tensor]],
+    ]:
+        """
+        Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this
+        will return an ActivationCache object, with a bunch of useful HookedTransformer-specific
+        methods; otherwise it will return a dictionary of activations as in HookedRootModule.
+        This function was copied directly from HookedTransformer.
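+
+        Example (sketch):
+            logits, cache = encoder.run_with_cache(frames, one_zero_attention_mask=frame_mask, remove_batch_dim=True)
+            pattern = cache["pattern", 0, "attn"]  # layer-0 attention pattern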
+ """ + out, cache_dict = super().run_with_cache( + *model_args, remove_batch_dim=remove_batch_dim, **kwargs + ) + if return_cache_object: + cache = ActivationCache(cache_dict, self, has_batch_dim=not remove_batch_dim) + return out, cache + else: + return out, cache_dict + + def to( # type: ignore + self, + device_or_dtype: Union[torch.device, str, torch.dtype], + print_details: bool = True, + ): + return devices.move_to_and_update_config(self, device_or_dtype, print_details) + + def cuda(self: T, device: Optional[Union[int, torch.device]] = None) -> T: + if isinstance(device, int): + return self.to(f"cuda:{device}") + elif device is None: + return self.to("cuda") + else: + return self.to(device) + + def cpu(self: T) -> T: + return self.to("cpu") + + def mps(self: T) -> T: + return self.to(torch.device("mps")) + + @classmethod + def from_pretrained( + cls, + model_name: str, + checkpoint_index: Optional[int] = None, + checkpoint_value: Optional[int] = None, + hf_model: Optional[Any] = None, + device: Optional[str] = None, + move_to_device: bool = True, + dtype: torch.dtype = torch.float32, + use_ctc: bool = False, + **from_pretrained_kwargs: Any, + ) -> HookedEncoder: + """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model.""" + logging.warning( + "Support for HuBERT in TransformerLens is currently experimental, until such a time when it has feature " + "parity with HookedTransformer and has been tested on real research tasks. Until then, backward " + "compatibility is not guaranteed. Please see the docs for information on the limitations of the current " + "implementation." + "\n" + "If using HuBERT for interpretability research, keep in mind that HuBERT has some significant architectural " + "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning " + "that the last LayerNorm in a block cannot be folded." 
+ ) + + assert not ( + from_pretrained_kwargs.get("load_in_8bit", False) + or from_pretrained_kwargs.get("load_in_4bit", False) + ), "Quantization not supported" + + if "torch_dtype" in from_pretrained_kwargs: + dtype = from_pretrained_kwargs["torch_dtype"] + + official_model_name = loading.get_official_model_name(model_name) + + if model_name.endswith("-ft") and use_ctc: + # fine-tuned model (has CTC head) + use_ctc = True + else: + # pretraining-only model (no CTC) + use_ctc = False + + cfg = loading.get_pretrained_model_config( + official_model_name, + checkpoint_index=checkpoint_index, + checkpoint_value=checkpoint_value, + fold_ln=False, + device=device, + n_devices=1, + dtype=dtype, + **from_pretrained_kwargs, + ) + + state_dict = loading.get_pretrained_state_dict( + official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs + ) + + model = cls(cfg, move_to_device=False, model_name=official_model_name, use_ctc=use_ctc) + + model.load_state_dict(state_dict, strict=False) + + if move_to_device: + model.to(cfg.device) + + print(f"Loaded pretrained model {model_name} into HookedEncoder") + + return model + + @property + def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the key weights across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.W_K for block in self.blocks], dim=0) + + @property + def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the query weights across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.W_Q for block in self.blocks], dim=0) + + @property + def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the value weights across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.W_V for block in self.blocks], dim=0) + + @property + def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: + """Stacks the attn output weights across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.W_O for block in self.blocks], dim=0) + + @property + def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: + """Stacks the MLP input weights across all layers""" + for block in self.blocks: + assert isinstance(block.mlp, MLP) + return torch.stack([block.mlp.W_in for block in self.blocks], dim=0) + + @property + def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: + """Stacks the MLP output weights across all layers""" + for block in self.blocks: + assert isinstance(block.mlp, MLP) + return torch.stack([block.mlp.W_out for block in self.blocks], dim=0) + + @property + def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the key biases across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.b_K for block in self.blocks], dim=0) + + @property + def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the query biases across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.b_Q for block in self.blocks], dim=0) + + @property + def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the value biases across all layers""" + for block in self.blocks: + assert isinstance(block.attn, 
Attention) + return torch.stack([block.attn.b_V for block in self.blocks], dim=0) + + @property + def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: + """Stacks the attn output biases across all layers""" + for block in self.blocks: + assert isinstance(block.attn, Attention) + return torch.stack([block.attn.b_O for block in self.blocks], dim=0) + + @property + def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: + """Stacks the MLP input biases across all layers""" + for block in self.blocks: + assert isinstance(block.mlp, MLP) + return torch.stack([block.mlp.b_in for block in self.blocks], dim=0) + + @property + def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: + """Stacks the MLP output biases across all layers""" + for block in self.blocks: + assert isinstance(block.mlp, MLP) + return torch.stack([block.mlp.b_out for block in self.blocks], dim=0) + + @property + def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] + """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head. + Useful for visualizing attention patterns.""" + return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) + + @property + def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] + """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head.""" + return FactoredMatrix(self.W_V, self.W_O) + + def all_head_labels(self) -> List[str]: + """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index.""" + return [f"L{l}H{h}" for l in range(self.cfg.n_layers) for h in range(self.cfg.n_heads)] diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py index 1e2ff1e1a..7e6183c71 100644 --- a/transformer_lens/__init__.py +++ b/transformer_lens/__init__.py @@ -13,6 +13,7 @@ from .HookedTransformer import HookedTransformer from .SVDInterpreter import SVDInterpreter from .HookedEncoder import HookedEncoder +from .HookedAudioEncoder import HookedAudioEncoder from .HookedEncoderDecoder import HookedEncoderDecoder from .BertNextSentencePrediction import BertNextSentencePrediction from . 
import head_detector diff --git a/transformer_lens/components/bert_pooler.py b/transformer_lens/components/bert_pooler.py index cd205bf7f..4f23bba14 100644 --- a/transformer_lens/components/bert_pooler.py +++ b/transformer_lens/components/bert_pooler.py @@ -33,4 +33,5 @@ def forward( first_token_tensor = resid[:, 0] pooled_output = torch.matmul(first_token_tensor, self.W) + self.b pooled_output = self.hook_pooler_out(self.activation(pooled_output)) + # pooled_output = self.hook_pooler_out(pooled_output) return pooled_output diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 8bfb6315d..c0c49d09e 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -18,7 +18,9 @@ AutoConfig, AutoModelForCausalLM, BertForPreTraining, + HubertModel, T5ForConditionalGeneration, + Wav2Vec2Model, ) import transformer_lens.utils as utils @@ -30,6 +32,7 @@ convert_gemma_weights, convert_gpt2_weights, convert_gptj_weights, + convert_hubert_weights, convert_llama_weights, convert_mingpt_weights, convert_mistral_weights, @@ -59,6 +62,9 @@ "facebook/opt-13b", "facebook/opt-30b", "facebook/opt-66b", + "facebook/hubert-base-ls960", + "facebook/wav2vec2-base", + "facebook/wav2vec2-large", "EleutherAI/gpt-neo-125M", "EleutherAI/gpt-neo-1.3B", "EleutherAI/gpt-neo-2.7B", @@ -610,6 +616,9 @@ "google-bert/bert-base-uncased": ["bert-base-uncased"], "google-bert/bert-large-cased": ["bert-large-cased"], "google-bert/bert-large-uncased": ["bert-large-uncased"], + "facebook/hubert-base-ls960": ["facebook/hubert-base-ls960", "hubert-base-ls960"], + "facebook/wav2vec2-base": ["facebook/wav2vec2-base", "wav2vec2-base", "w2v2-base"], + "facebook/wav2vec2-large": ["facebook/wav2vec2-large", "wav2vec2-large", "w2v2-large"], "roneneldan/TinyStories-1M": ["tiny-stories-1M"], "roneneldan/TinyStories-3M": ["tiny-stories-3M"], "roneneldan/TinyStories-8M": ["tiny-stories-8M"], @@ -1176,6 +1185,51 @@ def convert_hf_model_config(model_name: str, **kwargs: Any): } rotary_pct = hf_config.rotary_pct cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"]) + elif architecture == "HubertModel": + # Basic transformer configuration + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + # HuBERT operates on audio frames, not tokens — n_ctx is flexible + "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), + "eps": hf_config.layer_norm_eps, + "act_fn": getattr(hf_config, "hidden_act", "gelu"), + "attention_dir": "bidirectional", + "d_vocab": -1, # no text vocabulary + } + elif "wav2vec2-base" in official_model_name or "wav2vec2-large" in official_model_name: + # Basic transformer configuration + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + # HuBERT operates on audio frames, not tokens — n_ctx is flexible + "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), + "eps": hf_config.layer_norm_eps, + "act_fn": getattr(hf_config, "hidden_act", "gelu"), + "attention_dir": "bidirectional", + "d_vocab": -1, # no text vocabulary + } + elif architecture == "HubertForCTC": + # Basic transformer configuration + cfg_dict = { + "d_model": 
hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": getattr(hf_config, "max_position_embeddings", 8192), + "eps": hf_config.layer_norm_eps, + "act_fn": getattr(hf_config, "hidden_act", "gelu"), + "attention_dir": "bidirectional", + # For CTC models: + "d_vocab": hf_config.vocab_size, # text vocab from tokenizer + } elif architecture == "BertForMaskedLM": # All supported Bert architectures have the same config, # so we can use the BertForMaskedLM config for all of them @@ -1921,6 +1975,20 @@ def get_pretrained_state_dict( huggingface_token = os.environ.get("HF_TOKEN", "") if official_model_name in NON_HF_HOSTED_MODEL_NAMES: raise NotImplementedError("Model not hosted on HuggingFace, must pass in hf_model") + elif "hubert" in official_model_name: + hf_model = HubertModel.from_pretrained( + official_model_name, + torch_dtype=dtype, + token=huggingface_token if len(huggingface_token) > 0 else None, + **kwargs, + ) + elif "wav2vec2" in official_model_name: + hf_model = Wav2Vec2Model.from_pretrained( + official_model_name, + torch_dtype=dtype, + token=huggingface_token if len(huggingface_token) > 0 else None, + **kwargs, + ) elif "bert" in official_model_name: hf_model = BertForPreTraining.from_pretrained( official_model_name, @@ -1960,6 +2028,12 @@ def get_pretrained_state_dict( state_dict = convert_neox_weights(hf_model, cfg) elif cfg.original_architecture == "LlamaForCausalLM": state_dict = convert_llama_weights(hf_model, cfg) + elif cfg.original_architecture == "HubertModel": + state_dict = convert_hubert_weights(hf_model, cfg) + elif cfg.original_architecture == "Wav2Vec2Model" or cfg.original_architecture == "Wav2Vec2ForPreTraining": + state_dict = convert_hubert_weights(hf_model, cfg) + elif cfg.original_architecture == "HubertForCTC": + state_dict = convert_hubert_weights(hf_model, cfg) elif cfg.original_architecture == "BertForMaskedLM": state_dict = convert_bert_weights(hf_model, cfg) elif cfg.original_architecture == "T5ForConditionalGeneration": diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py index c5ea9581b..daaffe472 100644 --- a/transformer_lens/pretrained/weight_conversions/__init__.py +++ b/transformer_lens/pretrained/weight_conversions/__init__.py @@ -19,3 +19,4 @@ from .nanogpt import convert_nanogpt_weights from .t5 import convert_t5_weights from .neel_solu_old import convert_neel_solu_old_weights +from .hubert import convert_hubert_weights diff --git a/transformer_lens/pretrained/weight_conversions/hubert.py b/transformer_lens/pretrained/weight_conversions/hubert.py new file mode 100644 index 000000000..a13141725 --- /dev/null +++ b/transformer_lens/pretrained/weight_conversions/hubert.py @@ -0,0 +1,120 @@ +import einops + +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + + +def convert_hubert_weights(hf_model, cfg: HookedTransformerConfig): + """ + Convert transformer encoder weights from a HuggingFace HuBERT model + into the state_dict expected by Transformer-Lens' HookedEncoder. + + Notes: + - This intentionally skips the convolutional frontend and feature_projection. + Those are used directly from the HF model (hf_model.feature_extractor, hf_model.feature_projection). + - Use model.load_state_dict(state_dict, strict=False) to load these. 
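+
+    Example (sketch):
+        hf_model = HubertModel.from_pretrained("facebook/hubert-base-ls960")
+        state_dict = convert_hubert_weights(hf_model, cfg)
+        hooked_encoder.load_state_dict(state_dict, strict=False)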
+ """ + state_dict = {} + + # Try to find the encoder layer list (different HF variants use .layers or .layer) + encoder = getattr(hf_model, "encoder", None) + if encoder is None: + raise ValueError("hf_model has no .encoder attribute") + + encoder_layers = getattr(encoder, "layers", None) or getattr(encoder, "layer", None) + if encoder_layers is None: + # maybe hf_model itself is the encoder (unlikely), or a wrapped attribute + raise ValueError("Couldn't find encoder.layers or encoder.layer on hf_model.encoder") + + # Use cfg dims for reshaping + d_model = cfg.d_model + n_heads = cfg.n_heads + # d_head = d_model // n_heads # implicit if needed + + for l, layer in enumerate(encoder_layers): + # --- Attention module --- + # Some HF variants might call it `attention`, others `self_attn` etc. + att = getattr(layer, "attention", None) or getattr(layer, "self_attn", None) + if att is None: + raise AttributeError(f"Encoder layer {l} has no 'attention' or 'self_attn' attribute") + + # q/k/v/out proj names in HuBERT's HubertAttention: q_proj, k_proj, v_proj, out_proj + # fall back to common alternatives if present + q_w = getattr(att, "q_proj", None) + k_w = getattr(att, "k_proj", None) + v_w = getattr(att, "v_proj", None) + o_w = getattr(att, "out_proj", None) or getattr(att, "proj", None) + + if any(x is None for x in (q_w, k_w, v_w, o_w)): + # Try alternate nested attributes like att.q, att.k, att.v, att.o + q_w = q_w or getattr(att, "q", None) + k_w = k_w or getattr(att, "k", None) + v_w = v_w or getattr(att, "v", None) + o_w = o_w or getattr(att, "o", None) + + if any(x is None for x in (q_w, k_w, v_w, o_w)): + raise AttributeError(f"Could not find q/k/v/out projections in layer {l}. Found: {att}") + + # weights are Linear modules: weight shape (out, in) => same convention as Bert conversion + # reshape to Transformer-Lens expected shapes using einops + state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange(q_w.weight, "(i h) m -> i m h", i=n_heads) + if q_w.bias is not None: + state_dict[f"blocks.{l}.attn.b_Q"] = einops.rearrange(q_w.bias, "(i h) -> i h", i=n_heads) + + state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange(k_w.weight, "(i h) m -> i m h", i=n_heads) + if k_w.bias is not None: + state_dict[f"blocks.{l}.attn.b_K"] = einops.rearrange(k_w.bias, "(i h) -> i h", i=n_heads) + + state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange(v_w.weight, "(i h) m -> i m h", i=n_heads) + if v_w.bias is not None: + state_dict[f"blocks.{l}.attn.b_V"] = einops.rearrange(v_w.bias, "(i h) -> i h", i=n_heads) + + state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange(o_w.weight, "m (i h) -> i h m", i=n_heads) + if o_w.bias is not None: + state_dict[f"blocks.{l}.attn.b_O"] = o_w.bias + + # --- Layer norms inside the layer --- + # HuBERT layer has `layer.layer_norm` and `layer.final_layer_norm` + ln1 = getattr(layer, "layer_norm", None) + ln2 = getattr(layer, "final_layer_norm", None) + if ln1 is None or ln2 is None: + # try alternative names + ln1 = ln1 or getattr(layer, "attention_norm", None) + ln2 = ln2 or getattr(layer, "output_layer_norm", None) + + if ln1 is not None: + state_dict[f"blocks.{l}.ln1.w"] = ln1.weight + state_dict[f"blocks.{l}.ln1.b"] = ln1.bias + if ln2 is not None: + state_dict[f"blocks.{l}.ln2.w"] = ln2.weight + state_dict[f"blocks.{l}.ln2.b"] = ln2.bias + + # --- Feed-forward / MLP --- + # HuBERT uses `feed_forward` which contains intermediate_dense and output_dense + ff = getattr(layer, "feed_forward", None) or getattr(layer, "feedforward", None) or getattr(layer, "ff", None) + if 
ff is None: + raise AttributeError(f"Layer {l} has no feed_forward/ff attribute") + + # Many implementations name them intermediate_dense and output_dense + fc1 = getattr(ff, "intermediate_dense", None) or getattr(ff, "fc1", None) or getattr(ff, "linear1", None) + fc2 = getattr(ff, "output_dense", None) or getattr(ff, "fc2", None) or getattr(ff, "linear2", None) + + if fc1 is None or fc2 is None: + raise AttributeError(f"Could not find FFN dense layers in layer {l}: {ff}") + + # fc1.weight shape: (d_mlp, d_model) -> Transformer-Lens expects (d_model, d_mlp) + state_dict[f"blocks.{l}.mlp.W_in"] = einops.rearrange(fc1.weight, "mlp model -> model mlp") + if fc1.bias is not None: + state_dict[f"blocks.{l}.mlp.b_in"] = fc1.bias + + # fc2.weight shape: (d_model, d_mlp) -> Transformer-Lens expects (d_mlp, d_model) + state_dict[f"blocks.{l}.mlp.W_out"] = einops.rearrange(fc2.weight, "model mlp -> mlp model") + if fc2.bias is not None: + state_dict[f"blocks.{l}.mlp.b_out"] = fc2.bias + + # --- Optional: encoder-level layer_norm (HubertModel.encoder.layer_norm) --- + if hasattr(hf_model.encoder, "layer_norm"): + ln_final = hf_model.encoder.layer_norm + state_dict["ln_final.w"] = ln_final.weight + state_dict["ln_final.b"] = ln_final.bias + + return state_dict
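+
+
+def _check_converted_shapes(state_dict: dict, cfg: HookedTransformerConfig) -> None:
+    """Optional sanity check (illustrative; not called anywhere by default).
+
+    Verifies that the converted attention weights have the shapes the hooked blocks expect:
+    W_Q/W_K/W_V as (n_heads, d_model, d_head) and W_O as (n_heads, d_head, d_model).
+    """
+    for l in range(cfg.n_layers):
+        for name in ("W_Q", "W_K", "W_V"):
+            w = state_dict[f"blocks.{l}.attn.{name}"]
+            assert tuple(w.shape) == (cfg.n_heads, cfg.d_model, cfg.d_head), (l, name, tuple(w.shape))
+        w_o = state_dict[f"blocks.{l}.attn.W_O"]
+        assert tuple(w_o.shape) == (cfg.n_heads, cfg.d_head, cfg.d_model), (l, "W_O", tuple(w_o.shape))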