6 changes: 2 additions & 4 deletions tests/acceptance/test_activation_cache.py
@@ -242,11 +242,9 @@ def test_accumulated_resid_with_apply_ln():
# Run the model and cache all activations
_, cache = model.run_with_cache(tokens)

# Get accumulated resid and apply ln seperately (cribbed notebook code)
# Get accumulated resid and apply final ln separately
accumulated_residual = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1)
ref_scaled_residual_stack = cache.apply_ln_to_stack(
accumulated_residual, layer=-1, pos_slice=-1
)
ref_scaled_residual_stack = model.ln_final(accumulated_residual)

# Get scaled_residual_stack using apply_ln parameter
scaled_residual_stack = cache.accumulated_resid(
48 changes: 34 additions & 14 deletions transformer_lens/ActivationCache.py
@@ -340,7 +340,23 @@ def accumulated_resid(

To project this into the vocabulary space, remember that there is a final layer norm in most
decoder-only transformers. Therefore, you need to first apply the final layer norm (which
can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`).
can be done with `apply_ln`), and then multiply by the unembedding matrix (:math:`W_U`)
and optionally add the unembedding bias (:math:`b_U`).

**Note on bias terms:** There are two valid approaches for the final projection:

1. **With bias terms:** Use `model.unembed(normalized_resid)`, which applies both :math:`W_U`
and :math:`b_U` (equivalent to `normalized_resid @ model.W_U + model.b_U`). This works
correctly with both `fold_ln=True` and `fold_ln=False`, since the biases are handled
consistently in either case.
2. **Without bias terms:** Use only `normalized_resid @ model.W_U`. If you take this approach,
instantiate the model with `fold_ln=True`, which folds the layer norm scaling into
:math:`W_U` and the layer norm bias into :math:`b_U`. Since `apply_ln=True` then applies the
(now parameter-free) layer norm and you skip :math:`b_U`, no bias terms are included. With
`fold_ln=False`, the layer norm bias would still be applied, which is typically not desired
when excluding bias terms.

Both approaches are commonly used in the literature and are valid interpretability choices.
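
A minimal sketch of the two options (illustrative only; it assumes `model` is a loaded
`HookedTransformer` and `cache` is an `ActivationCache` from running it on some prompt;
the variable names are placeholders, not part of the API)::

    # Accumulated residual stream at the final position, with the final
    # layer norm applied to every intermediate state.
    normalized_resid = cache.accumulated_resid(
        layer=-1, pos_slice=-1, apply_ln=True
    )  # [component, batch, d_model]

    # Approach 1 (with bias terms): apply both W_U and b_U.
    logits_with_bias = normalized_resid @ model.W_U + model.b_U

    # Approach 2 (without bias terms): apply W_U only; prefer fold_ln=True
    # when loading the model so no layer norm bias is baked into the stack.
    logits_without_bias = normalized_resid @ model.W_U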

If you instead want to look at contributions to the residual stream from each component
(e.g. for direct logit attribution), see :meth:`decompose_resid` instead, or
@@ -352,11 +368,10 @@ def accumulated_resid(
Logit Lens analysis can be done as follows:

>>> from transformer_lens import HookedTransformer
>>> from einops import einsum
>>> import torch
>>> import pandas as pd

>>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu")
>>> model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu", fold_ln=True)
Loaded pretrained model tiny-stories-1M into HookedTransformer

>>> prompt = "Why did the chicken cross the"
@@ -371,20 +386,24 @@ def accumulated_resid(
>>> print(last_token_accum.shape) # layer, d_model
torch.Size([9, 64])


>>> W_U = model.W_U
>>> print(W_U.shape)
torch.Size([64, 50257])

>>> layers_unembedded = einsum(
... last_token_accum,
... W_U,
... "layer d_model, d_model d_vocab -> layer d_vocab"
... )
>>> print(layers_unembedded.shape)
>>> # Project to vocabulary without unembedding bias
>>> layers_logits = last_token_accum @ W_U # layer, d_vocab
>>> print(layers_logits.shape)
torch.Size([9, 50257])

>>> # If you also want to include the unembedding bias, add model.b_U:
>>> # layers_logits = layers_logits + model.b_U

>>> # Get the rank of the correct answer by layer
>>> sorted_indices = torch.argsort(layers_unembedded, dim=1, descending=True)
>>> sorted_indices = torch.argsort(layers_logits, dim=1, descending=True)
>>> rank_answer = (sorted_indices == 2975).nonzero(as_tuple=True)[1]
>>> print(pd.Series(rank_answer, index=labels))
0_pre 4442
@@ -408,7 +427,10 @@ def accumulated_resid(
incl_mid:
Whether to return `resid_mid` for all previous layers.
apply_ln:
Whether to apply LayerNorm to the stack.
Whether to apply the final layer norm to the stack. When True, applies
`model.ln_final`, which recomputes normalization statistics (mean and
variance/RMS) for each intermediate state in the stack, transforming the
activations into the format expected by the unembedding layer.
pos_slice:
A slice object to apply to the pos dimension. Defaults to None, do nothing.
mlp_input:
@@ -443,9 +465,7 @@ def accumulated_resid(
components_list = [pos_slice.apply(c, dim=-2) for c in components_list]
components = torch.stack(components_list, dim=0)
if apply_ln:
components = self.apply_ln_to_stack(
components, layer, pos_slice=pos_slice, mlp_input=mlp_input
)
components = self.model.ln_final(components)
if return_labels:
return components, labels
else: