diff --git a/tests/acceptance/test_activation_cache.py b/tests/acceptance/test_activation_cache.py
index 7547f57e1..5770b342b 100644
--- a/tests/acceptance/test_activation_cache.py
+++ b/tests/acceptance/test_activation_cache.py
@@ -242,11 +242,9 @@ def test_accumulated_resid_with_apply_ln():
     # Run the model and cache all activations
     _, cache = model.run_with_cache(tokens)
 
-    # Get accumulated resid and apply ln seperately (cribbed notebook code)
+    # Get accumulated resid and apply final ln separately
     accumulated_residual = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1)
-    ref_scaled_residual_stack = cache.apply_ln_to_stack(
-        accumulated_residual, layer=-1, pos_slice=-1
-    )
+    ref_scaled_residual_stack = model.ln_final(accumulated_residual)
 
     # Get scaled_residual_stack using apply_ln parameter
     scaled_residual_stack = cache.accumulated_resid(
diff --git a/transformer_lens/ActivationCache.py b/transformer_lens/ActivationCache.py
index 25b57a3ee..e1763f7f4 100644
--- a/transformer_lens/ActivationCache.py
+++ b/transformer_lens/ActivationCache.py
@@ -340,7 +340,23 @@ def accumulated_resid(
 
         To project this into the vocabulary space, remember that there is a final layer norm in most
         decoder-only transformers. Therefore, you need to first apply the final layer norm (which
-        can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`).
+        can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`)
+        and optionally add the unembedding bias (:math:`b_U`).
+
+        **Note on bias terms:** There are two valid approaches for the final projection:
+
+        1. **With bias terms:** Use `model.unembed(normalized_resid)`, which applies both :math:`W_U`
+           and :math:`b_U` (equivalent to `normalized_resid @ model.W_U + model.b_U`). This works
+           correctly with both `fold_ln=True` and `fold_ln=False` settings, as the biases are
+           handled consistently.
+        2. **Without bias terms:** Use only `normalized_resid @ model.W_U`. If taking this approach,
+           you should instantiate the model with `fold_ln=True`, which folds the layer norm scaling
+           into :math:`W_U` and the layer norm bias into :math:`b_U`. Since `apply_ln=True` applies
+           the (now parameter-free) layer norm and you skip :math:`b_U`, no bias terms are
+           included. With `fold_ln=False`, the layer norm bias would still be applied, which is
+           typically not desired when excluding bias terms.
+
+        Both approaches are commonly used in the literature and are valid interpretability choices.
 
         If you instead want to look at contributions to the residual stream from each component
         (e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or
@@ -352,11 +368,10 @@
         Logit Lens analysis can be done as follows:
 
             >>> from transformer_lens import HookedTransformer
-            >>> from einops import einsum
             >>> import torch
             >>> import pandas as pd
 
-            >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu")
+            >>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu", fold_ln=True)
             Loaded pretrained model tiny-stories-1M into HookedTransformer
 
             >>> prompt = "Why did the chicken cross the"
@@ -371,20 +386,24 @@
 
             >>> print(last_token_accum.shape)  # layer, d_model
             torch.Size([9, 64])
+
             >>> W_U = model.W_U
             >>> print(W_U.shape)
             torch.Size([64, 50257])
 
-            >>> layers_unembedded = einsum(
-            ...         last_token_accum,
-            ...         W_U,
-            ...         "layer d_model, d_model d_vocab -> layer d_vocab"
-            ... )
-            >>> print(layers_unembedded.shape)
+            >>> # Project to vocabulary without unembedding bias
+            >>> layers_logits = last_token_accum @ W_U  # layer, d_vocab
+            >>> print(layers_logits.shape)
+            torch.Size([9, 50257])
+
+            >>> # If you want to apply the unembedding bias, add b_U when present:
+            >>> # b_U = getattr(model, "b_U", None)
+            >>> # layers_logits = layers_logits + b_U if b_U is not None else layers_logits
+            >>> # print(layers_logits.shape)
             torch.Size([9, 50257])
 
             >>> # Get the rank of the correct answer by layer
-            >>> sorted_indices = torch.argsort(layers_unembedded, dim=1, descending=True)
+            >>> sorted_indices = torch.argsort(layers_logits, dim=1, descending=True)
             >>> rank_answer = (sorted_indices == 2975).nonzero(as_tuple=True)[1]
             >>> print(pd.Series(rank_answer, index=labels))
             0_pre      4442
@@ -408,7 +427,10 @@
             incl_mid:
                 Whether to return `resid_mid` for all previous layers.
             apply_ln:
-                Whether to apply LayerNorm to the stack.
+                Whether to apply the final layer norm to the stack. When True, applies
+                `model.ln_final`, which recomputes normalization statistics (mean and
+                variance/RMS) for each intermediate state in the stack, transforming the
+                activations into the format expected by the unembedding layer.
             pos_slice:
                 A slice object to apply to the pos dimension. Defaults to None, do nothing.
             mlp_input:
@@ -443,9 +465,7 @@
             components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
         components = torch.stack(components_list, dim=0)
         if apply_ln:
-            components = self.apply_ln_to_stack(
-                components, layer, pos_slice=pos_slice, mlp_input=mlp_input
-            )
+            components = self.model.ln_final(components)
         if return_labels:
             return components, labels
         else:
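A minimal, self-contained sketch of the two projection conventions documented above (not part of the patch). It reuses the `tiny-stories-1M` checkpoint and " road" answer from the docstring example; the prompt, the variable names, and the comparison loop at the end are illustrative choices only, not anything specified by this PR.

# Sketch (not part of the patch): logit-lens projection of the accumulated
# residual stream with and without the unembedding bias, following the two
# conventions described in the updated accumulated_resid docstring.
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu", fold_ln=True)
prompt = "Why did the chicken cross the"
answer_token = model.to_single_token(" road")

_, cache = model.run_with_cache(prompt)

# Residual stream before/after every block at the final position, already
# passed through the (folded, parameter-free) final layer norm via apply_ln=True.
accum_resid, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, apply_ln=True, return_labels=True
)  # [component, batch, d_model]

# Convention 2 in the docstring: no bias terms (relies on fold_ln=True).
logits_no_bias = accum_resid @ model.W_U  # [component, batch, d_vocab]

# Convention 1 in the docstring: include b_U, equivalent to model.unembed(...).
logits_with_bias = accum_resid @ model.W_U + model.b_U

# Rank of the correct answer token at each point in the residual stream.
for name, logits in [("no b_U", logits_no_bias), ("with b_U", logits_with_bias)]:
    sorted_tokens = torch.argsort(logits[:, 0], dim=-1, descending=True)
    ranks = (sorted_tokens == answer_token).nonzero(as_tuple=True)[1]
    print(name, dict(zip(labels, ranks.tolist())))

Because fold_ln=True moves the final layer norm bias into b_U, any difference between the two printouts is exactly the bias term the docstring's note discusses.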