diff --git a/.gitignore b/.gitignore
index 32da76df8..be879b728 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,4 +22,4 @@ docs/source/generated
# docs/source/_static/model_table
**.orig
.venv
-
+.env
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 63e6e310a..86d448657 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -33,7 +33,7 @@
"notebook.formatOnSave.enabled": true,
"pylint.importStrategy": "fromEnvironment",
"python.testing.pytestArgs": [
- "transformer_lens",
+ "tests"
],
"python.testing.pytestEnabled": true,
"rewrap.autoWrap.enabled": true,
diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb
index 2e4d09421..97874c19d 100644
--- a/demos/Main_Demo.ipynb
+++ b/demos/Main_Demo.ipynb
@@ -45,15 +45,17 @@
},
{
"cell_type": "code",
- "execution_count": 62,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
+ "\n",
"DEVELOPMENT_MODE = False\n",
"# Detect if we're running in Google Colab\n",
"try:\n",
" import google.colab\n",
+ "\n",
" IN_COLAB = True\n",
" print(\"Running as a Colab notebook\")\n",
"except:\n",
@@ -69,30 +71,24 @@
"# Hot reload in development mode & not running on the CD\n",
"if not IN_COLAB:\n",
" from IPython import get_ipython\n",
+ "\n",
" ip = get_ipython()\n",
" if not ip.extension_manager.loaded:\n",
- " ip.extension_manager.load('autoreload')\n",
+ " ip.extension_manager.load(\"autoreload\")\n",
" %autoreload 2\n",
- " \n",
- "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n"
+ "\n",
+ "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\""
]
},
{
"cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Using renderer: colab\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n",
"import plotly.io as pio\n",
+ "\n",
"if IN_COLAB or not DEVELOPMENT_MODE:\n",
" pio.renderers.default = \"colab\"\n",
"else:\n",
@@ -102,40 +98,19 @@
},
{
"cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "
\n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"import circuitsvis as cv\n",
+ "\n",
"# Testing that the library works\n",
"cv.examples.hello(\"Neel\")"
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -153,7 +128,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -162,7 +137,7 @@
"from transformer_lens.hook_points import (\n",
" HookPoint,\n",
") # Hooking utilities\n",
- "from transformer_lens import FactoredMatrix\n",
+ "from transformer_lens import FactoredMatrix, HookedTransformer\n",
"from transformer_lens.model_bridge import TransformerBridge"
]
},
@@ -175,20 +150,9 @@
},
{
"cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- ""
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"torch.set_grad_enabled(False)"
]
@@ -202,20 +166,28 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n",
- " px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n",
+ " px.imshow(\n",
+ " utils.to_numpy(tensor),\n",
+ " color_continuous_midpoint=0.0,\n",
+ " color_continuous_scale=\"RdBu\",\n",
+ " labels={\"x\": xaxis, \"y\": yaxis},\n",
+ " **kwargs,\n",
+ " ).show(renderer)\n",
+ "\n",
"\n",
"def line(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n",
- " px.line(utils.to_numpy(tensor), labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n",
+ " px.line(utils.to_numpy(tensor), labels={\"x\": xaxis, \"y\": yaxis}, **kwargs).show(renderer)\n",
+ "\n",
"\n",
"def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", renderer=None, **kwargs):\n",
" x = utils.to_numpy(x)\n",
" y = utils.to_numpy(y)\n",
- " px.scatter(y=y, x=x, labels={\"x\":xaxis, \"y\":yaxis, \"color\":caxis}, **kwargs).show(renderer)"
+ " px.scatter(y=y, x=x, labels={\"x\": xaxis, \"y\": yaxis, \"color\": caxis}, **kwargs).show(renderer)"
]
},
{
@@ -254,12 +226,12 @@
"metadata": {},
"outputs": [],
"source": [
- "device = utils.get_device()\n"
+ "device = utils.get_device()"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -282,17 +254,9 @@
},
{
"cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Model loss: tensor(4.1758)\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"model_description_text = \"\"\"## Loading Models\n",
"\n",
@@ -320,17 +284,9 @@
},
{
"cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "cpu\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"gpt2_text = \"Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets.\"\n",
"gpt2_tokens = model.to_tokens(gpt2_text)\n",
@@ -354,18 +310,9 @@
},
{
"cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n",
- "torch.Size([12, 33, 33])\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"print(type(gpt2_cache))\n",
"attention_pattern = gpt2_cache[\"pattern\", 0, \"attn\"]\n",
@@ -375,38 +322,9 @@
},
{
"cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Layer 0 Head Attention Patterns:\n"
- ]
- },
- {
- "data": {
- "text/html": [
- "\n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"print(\"Layer 0 Head Attention Patterns:\")\n",
"cv.attention.attention_patterns(tokens=gpt2_str_tokens, attention=attention_pattern)"
@@ -421,13 +339,15 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"attn_hook_name = \"blocks.0.attn.hook_pattern\"\n",
"attn_layer = 0\n",
- "_, gpt2_attn_cache = model.run_with_cache(gpt2_tokens, remove_batch_dim=True, stop_at_layer=attn_layer + 1, names_filter=[attn_hook_name])\n",
+ "_, gpt2_attn_cache = model.run_with_cache(\n",
+ " gpt2_tokens, remove_batch_dim=True, stop_at_layer=attn_layer + 1, names_filter=[attn_hook_name]\n",
+ ")\n",
"gpt2_attn = gpt2_attn_cache[attn_hook_name]\n",
"assert torch.allclose(gpt2_attn, attention_pattern)"
]
@@ -472,43 +392,31 @@
},
{
"cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Shape of the value tensor: torch.Size([1, 33, 12, 64])\n",
- "Original Loss: 3.999\n",
- "Ablated Loss: 5.453\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"layer_to_ablate = 0\n",
"head_index_to_ablate = 8\n",
"\n",
+ "\n",
"# We define a head ablation hook\n",
"# The type annotations are NOT necessary, they're just a useful guide to the reader\n",
- "# \n",
+ "#\n",
"def head_ablation_hook(\n",
- " value: Float[torch.Tensor, \"batch pos head_index d_head\"],\n",
- " hook: HookPoint\n",
+ " value: Float[torch.Tensor, \"batch pos head_index d_head\"], hook: HookPoint\n",
") -> Float[torch.Tensor, \"batch pos head_index d_head\"]:\n",
" print(f\"Shape of the value tensor: {value.shape}\")\n",
- " value[:, :, head_index_to_ablate, :] = 0.\n",
+ " value[:, :, head_index_to_ablate, :] = 0.0\n",
" return value\n",
"\n",
+ "\n",
"original_loss = model(gpt2_tokens, return_type=\"loss\")\n",
"ablated_loss = model.run_with_hooks(\n",
- " gpt2_tokens, \n",
- " return_type=\"loss\", \n",
- " fwd_hooks=[(\n",
- " utils.get_act_name(\"v\", layer_to_ablate), \n",
- " head_ablation_hook\n",
- " )]\n",
- " )\n",
+ " gpt2_tokens,\n",
+ " return_type=\"loss\",\n",
+ " fwd_hooks=[(utils.get_act_name(\"v\", layer_to_ablate), head_ablation_hook)],\n",
+ ")\n",
"print(f\"Original Loss: {original_loss.item():.3f}\")\n",
"print(f\"Ablated Loss: {ablated_loss.item():.3f}\")"
]
@@ -551,18 +459,9 @@
},
{
"cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Clean logit difference: 4.276\n",
- "Corrupted logit difference: -2.738\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"clean_prompt = \"After John and Mary went to the store, Mary gave a bottle of milk to\"\n",
"corrupted_prompt = \"After John and Mary went to the store, John gave a bottle of milk to\"\n",
@@ -570,6 +469,7 @@
"clean_tokens = model.to_tokens(clean_prompt)\n",
"corrupted_tokens = model.to_tokens(corrupted_prompt)\n",
"\n",
+ "\n",
"def logits_to_logit_diff(logits, correct_answer=\" John\", incorrect_answer=\" Mary\"):\n",
" # model.to_single_token maps a string value of a single token to the token index for that token\n",
" # If the string is not a single token, it raises an error.\n",
@@ -577,6 +477,7 @@
" incorrect_index = model.to_single_token(incorrect_answer)\n",
" return logits[0, -1, correct_index] - logits[0, -1, incorrect_index]\n",
"\n",
+ "\n",
"# We run on the clean prompt with the cache so we store activations to patch in later.\n",
"clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n",
"clean_logit_diff = logits_to_logit_diff(clean_logits)\n",
@@ -600,34 +501,9 @@
},
{
"cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
- "To disable this warning, you can either:\n",
- "\t- Avoid using `tokenizers` before the fork if possible\n",
- "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "66fe24d5fe7e4ae989e613e0ff1fa394",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- " 0%| | 0/12 [00:00, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# Note: This cell has odd output behavior in CI\n",
@@ -635,15 +511,14 @@
"# We choose to act on the residual stream at the start of the layer, so we call it resid_pre\n",
"# The type annotations are a guide to the reader and are not necessary\n",
"def residual_stream_patching_hook(\n",
- " resid_pre: Float[torch.Tensor, \"batch pos d_model\"],\n",
- " hook: HookPoint,\n",
- " position: int\n",
+ " resid_pre: Float[torch.Tensor, \"batch pos d_model\"], hook: HookPoint, position: int\n",
") -> Float[torch.Tensor, \"batch pos d_model\"]:\n",
" # Each HookPoint has a name attribute giving the name of the hook.\n",
" clean_resid_pre = clean_cache[hook.name]\n",
" resid_pre[:, position, :] = clean_resid_pre[:, position, :]\n",
" return resid_pre\n",
"\n",
+ "\n",
"# We make a tensor to store the results for each patching run. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.\n",
"num_positions = len(clean_tokens[0])\n",
"ioi_patching_result = torch.zeros((model.cfg.n_layers, num_positions), device=model.cfg.device)\n",
@@ -653,13 +528,15 @@
" # Use functools.partial to create a temporary hook function with the position fixed\n",
" temp_hook_fn = partial(residual_stream_patching_hook, position=position)\n",
" # Run the model with the patching hook\n",
- " patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[\n",
- " (utils.get_act_name(\"resid_pre\", layer), temp_hook_fn)\n",
- " ])\n",
+ " patched_logits = model.run_with_hooks(\n",
+ " corrupted_tokens, fwd_hooks=[(utils.get_act_name(\"resid_pre\", layer), temp_hook_fn)]\n",
+ " )\n",
" # Calculate the logit difference\n",
" patched_logit_diff = logits_to_logit_diff(patched_logits).detach()\n",
" # Store the result, normalizing by the clean and corrupted logit difference so it's between 0 and 1 (ish)\n",
- " ioi_patching_result[layer, position] = (patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)"
+ " ioi_patching_result[layer, position] = (patched_logit_diff - corrupted_logit_diff) / (\n",
+ " clean_logit_diff - corrupted_logit_diff\n",
+ " )"
]
},
{
@@ -672,53 +549,19 @@
},
{
"cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"# Add the index to the end of the label, because plotly doesn't like duplicate labels\n",
"token_labels = [f\"{token}_{index}\" for index, token in enumerate(model.to_str_tokens(clean_tokens))]\n",
- "imshow(ioi_patching_result, x=token_labels, xaxis=\"Position\", yaxis=\"Layer\", title=\"Normalized Logit Difference After Patching Residual Stream on the IOI Task\")"
+ "imshow(\n",
+ " ioi_patching_result,\n",
+ " x=token_labels,\n",
+ " xaxis=\"Position\",\n",
+ " yaxis=\"Layer\",\n",
+ " title=\"Normalized Logit Difference After Patching Residual Stream on the IOI Task\",\n",
+ ")"
]
},
{
@@ -758,49 +601,9 @@
},
{
"cell_type": "code",
- "execution_count": 19,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"batch_size = 10\n",
"seq_len = 50\n",
@@ -812,7 +615,12 @@
"repeated_logits = model(repeated_tokens)\n",
"correct_log_probs = model.loss_fn(repeated_logits, repeated_tokens, per_token=True)\n",
"loss_by_position = einops.reduce(correct_log_probs, \"batch position -> position\", \"mean\")\n",
- "line(loss_by_position, xaxis=\"Position\", yaxis=\"Loss\", title=\"Loss by position on random repeated tokens\")"
+ "line(\n",
+ " loss_by_position,\n",
+ " xaxis=\"Position\",\n",
+ " yaxis=\"Loss\",\n",
+ " title=\"Loss by position on random repeated tokens\",\n",
+ ")"
]
},
{
@@ -833,74 +641,38 @@
},
{
"cell_type": "code",
- "execution_count": 20,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"# We make a tensor to store the induction score for each head. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.\n",
- "induction_score_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)\n",
+ "induction_score_store = torch.zeros(\n",
+ " (model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device\n",
+ ")\n",
+ "\n",
+ "\n",
"def induction_score_hook(\n",
" pattern: Float[torch.Tensor, \"batch head_index dest_pos source_pos\"],\n",
" hook: HookPoint,\n",
"):\n",
" # We take the diagonal of attention paid from each destination position to source positions seq_len-1 tokens back\n",
" # (This only has entries for tokens with index>=seq_len)\n",
- " induction_stripe = pattern.diagonal(dim1=-2, dim2=-1, offset=1-seq_len)\n",
+ " induction_stripe = pattern.diagonal(dim1=-2, dim2=-1, offset=1 - seq_len)\n",
" # Get an average score per head\n",
- " induction_score = einops.reduce(induction_stripe, \"batch head_index position -> head_index\", \"mean\")\n",
+ " induction_score = einops.reduce(\n",
+ " induction_stripe, \"batch head_index position -> head_index\", \"mean\"\n",
+ " )\n",
" # Store the result.\n",
" induction_score_store[hook.layer(), :] = induction_score\n",
"\n",
+ "\n",
"# We make a boolean filter on activation names, that's true only on attention pattern names.\n",
"pattern_hook_names_filter = lambda name: name.endswith(\"pattern\")\n",
"\n",
"model.run_with_hooks(\n",
- " repeated_tokens, \n",
- " return_type=None, # For efficiency, we don't need to calculate the logits\n",
- " fwd_hooks=[(\n",
- " pattern_hook_names_filter,\n",
- " induction_score_hook\n",
- " )]\n",
+ " repeated_tokens,\n",
+ " return_type=None, # For efficiency, we don't need to calculate the logits\n",
+ " fwd_hooks=[(pattern_hook_names_filter, induction_score_hook)],\n",
")\n",
"\n",
"imshow(induction_score_store, xaxis=\"Head\", yaxis=\"Layer\", title=\"Induction Score by Head\")"
@@ -917,59 +689,42 @@
},
{
"cell_type": "code",
- "execution_count": 21,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- " "
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"if IN_GITHUB:\n",
" torch.manual_seed(50)\n",
- " \n",
+ "\n",
"induction_head_layer = 5\n",
"induction_head_index = 5\n",
"size = (1, 20)\n",
"input_tensor = torch.randint(1000, 10000, size)\n",
"\n",
"single_random_sequence = input_tensor.to(model.cfg.device)\n",
- "repeated_random_sequence = einops.repeat(single_random_sequence, \"batch seq_len -> batch (2 seq_len)\")\n",
+ "repeated_random_sequence = einops.repeat(\n",
+ " single_random_sequence, \"batch seq_len -> batch (2 seq_len)\"\n",
+ ")\n",
+ "\n",
+ "\n",
"def visualize_pattern_hook(\n",
" pattern: Float[torch.Tensor, \"batch head_index dest_pos source_pos\"],\n",
" hook: HookPoint,\n",
"):\n",
" display(\n",
" cv.attention.attention_patterns(\n",
- " tokens=model.to_str_tokens(repeated_random_sequence), \n",
- " attention=pattern[0, induction_head_index, :, :][None, :, :] # Add a dummy axis, as CircuitsVis expects 3D patterns.\n",
+ " tokens=model.to_str_tokens(repeated_random_sequence),\n",
+ " attention=pattern[0, induction_head_index, :, :][\n",
+ " None, :, :\n",
+ " ], # Add a dummy axis, as CircuitsVis expects 3D patterns.\n",
" )\n",
" )\n",
"\n",
+ "\n",
"model.run_with_hooks(\n",
- " repeated_random_sequence, \n",
- " return_type=None, \n",
- " fwd_hooks=[(\n",
- " utils.get_act_name(\"pattern\", induction_head_layer), \n",
- " visualize_pattern_hook\n",
- " )]\n",
+ " repeated_random_sequence,\n",
+ " return_type=None,\n",
+ " fwd_hooks=[(utils.get_act_name(\"pattern\", induction_head_layer), visualize_pattern_hook)],\n",
")"
]
},
@@ -1000,7 +755,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1010,78 +765,47 @@
},
{
"cell_type": "code",
- "execution_count": 23,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"# We make a tensor to store the induction score for each head. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.\n",
- "distilgpt2_induction_score_store = torch.zeros((distilgpt2.cfg.n_layers, distilgpt2.cfg.n_heads), device=distilgpt2.cfg.device)\n",
+ "distilgpt2_induction_score_store = torch.zeros(\n",
+ " (distilgpt2.cfg.n_layers, distilgpt2.cfg.n_heads), device=distilgpt2.cfg.device\n",
+ ")\n",
+ "\n",
+ "\n",
"def induction_score_hook(\n",
" pattern: Float[torch.Tensor, \"batch head_index dest_pos source_pos\"],\n",
" hook: HookPoint,\n",
"):\n",
" # We take the diagonal of attention paid from each destination position to source positions seq_len-1 tokens back\n",
" # (This only has entries for tokens with index>=seq_len)\n",
- " induction_stripe = pattern.diagonal(dim1=-2, dim2=-1, offset=1-seq_len)\n",
+ " induction_stripe = pattern.diagonal(dim1=-2, dim2=-1, offset=1 - seq_len)\n",
" # Get an average score per head\n",
- " induction_score = einops.reduce(induction_stripe, \"batch head_index position -> head_index\", \"mean\")\n",
+ " induction_score = einops.reduce(\n",
+ " induction_stripe, \"batch head_index position -> head_index\", \"mean\"\n",
+ " )\n",
" # Store the result.\n",
" distilgpt2_induction_score_store[hook.layer(), :] = induction_score\n",
"\n",
+ "\n",
"# We make a boolean filter on activation names, that's true only on attention pattern names.\n",
"pattern_hook_names_filter = lambda name: name.endswith(\"pattern\")\n",
"\n",
"distilgpt2.run_with_hooks(\n",
- " repeated_tokens, \n",
- " return_type=None, # For efficiency, we don't need to calculate the logits\n",
- " fwd_hooks=[(\n",
- " pattern_hook_names_filter,\n",
- " induction_score_hook\n",
- " )]\n",
+ " repeated_tokens,\n",
+ " return_type=None, # For efficiency, we don't need to calculate the logits\n",
+ " fwd_hooks=[(pattern_hook_names_filter, induction_score_hook)],\n",
")\n",
"\n",
- "imshow(distilgpt2_induction_score_store, xaxis=\"Head\", yaxis=\"Layer\", title=\"Induction Score by Head in Distil GPT-2\")"
+ "imshow(\n",
+ " distilgpt2_induction_score_store,\n",
+ " xaxis=\"Head\",\n",
+ " yaxis=\"Layer\",\n",
+ " title=\"Induction Score by Head in Distil GPT-2\",\n",
+ ")"
]
},
{
@@ -1201,30 +925,11 @@
},
{
"cell_type": "code",
- "execution_count": 24,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "blocks.0.attn.W_Q torch.Size([12, 768, 64])\n",
- "blocks.0.attn.W_K torch.Size([12, 768, 64])\n",
- "blocks.0.attn.W_V torch.Size([12, 768, 64])\n",
- "blocks.0.attn.W_O torch.Size([12, 64, 768])\n",
- "blocks.0.attn.b_Q torch.Size([12, 64])\n",
- "blocks.0.attn.b_K torch.Size([12, 64])\n",
- "blocks.0.attn.b_V torch.Size([12, 64])\n",
- "blocks.0.attn.b_O torch.Size([768])\n",
- "blocks.0.mlp.W_in torch.Size([768, 3072])\n",
- "blocks.0.mlp.W_out torch.Size([3072, 768])\n",
- "blocks.0.mlp.b_in torch.Size([3072])\n",
- "blocks.0.mlp.b_out torch.Size([768])\n"
- ]
- }
- ],
- "source": [
- "for name, param in model.named_parameters():\n",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for name, param in model.tl_named_parameters():\n",
" if name.startswith(\"blocks.0.\"):\n",
" print(name, param.shape)"
]
@@ -1238,22 +943,11 @@
},
{
"cell_type": "code",
- "execution_count": 25,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "embed.W_E torch.Size([50257, 768])\n",
- "pos_embed.W_pos torch.Size([1024, 768])\n",
- "unembed.W_U torch.Size([768, 50257])\n",
- "unembed.b_U torch.Size([50257])\n"
- ]
- }
- ],
- "source": [
- "for name, param in model.named_parameters():\n",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for name, param in model.tl_named_parameters():\n",
" if not name.startswith(\"blocks\"):\n",
" print(name, param.shape)"
]
@@ -1276,66 +970,21 @@
},
{
"cell_type": "code",
- "execution_count": 26,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Num tokens: 10\n",
- "embed.hook_in torch.Size([1, 10])\n",
- "hook_embed torch.Size([1, 10, 768])\n",
- "pos_embed.hook_in torch.Size([1, 10])\n",
- "hook_pos_embed torch.Size([1, 10, 768])\n",
- "blocks.0.hook_resid_pre torch.Size([1, 10, 768])\n",
- "blocks.0.ln1.hook_in torch.Size([1, 10, 768])\n",
- "blocks.0.ln1.hook_scale torch.Size([1, 10, 1])\n",
- "blocks.0.ln1.hook_normalized torch.Size([1, 10, 768])\n",
- "blocks.0.ln1.hook_out torch.Size([1, 10, 768])\n",
- "blocks.0.hook_attn_in torch.Size([1, 10, 768])\n",
- "blocks.0.hook_q_input torch.Size([1, 10, 768])\n",
- "blocks.0.attn.hook_q torch.Size([1, 10, 12, 64])\n",
- "blocks.0.hook_k_input torch.Size([1, 10, 768])\n",
- "blocks.0.attn.hook_k torch.Size([1, 10, 12, 64])\n",
- "blocks.0.hook_v_input torch.Size([1, 10, 768])\n",
- "blocks.0.attn.hook_v torch.Size([1, 10, 12, 64])\n",
- "blocks.0.attn.hook_attn_scores torch.Size([1, 12, 10, 10])\n",
- "blocks.0.attn.hook_pattern torch.Size([1, 12, 10, 10])\n",
- "blocks.0.attn.hook_z torch.Size([1, 10, 12, 64])\n",
- "blocks.0.attn.o.hook_out torch.Size([1, 10, 768])\n",
- "blocks.0.attn.hook_hidden_states torch.Size([1, 10, 768])\n",
- "blocks.0.attn.hook_result torch.Size([1, 10, 768])\n",
- "blocks.0.hook_resid_mid torch.Size([1, 10, 768])\n",
- "blocks.0.ln2.hook_scale torch.Size([1, 10, 1])\n",
- "blocks.0.ln2.hook_normalized torch.Size([1, 10, 768])\n",
- "blocks.0.ln2.hook_out torch.Size([1, 10, 768])\n",
- "blocks.0.hook_mlp_in torch.Size([1, 10, 768])\n",
- "blocks.0.mlp.in.hook_in torch.Size([1, 10, 768])\n",
- "blocks.0.mlp.in.hook_in torch.Size([1, 10, 768])\n",
- "blocks.0.mlp.hook_pre torch.Size([1, 10, 3072])\n",
- "blocks.0.mlp.hook_post torch.Size([1, 10, 3072])\n",
- "blocks.0.mlp.out.hook_out torch.Size([1, 10, 768])\n",
- "blocks.0.hook_mlp_out torch.Size([1, 10, 768])\n",
- "blocks.0.mlp.out.hook_out torch.Size([1, 10, 768])\n",
- "blocks.0.hook_resid_post torch.Size([1, 10, 768])\n",
- "ln_final.hook_in torch.Size([1, 10, 768])\n",
- "ln_final.hook_scale torch.Size([1, 10, 1])\n",
- "ln_final.hook_normalized torch.Size([1, 10, 768])\n",
- "ln_final.hook_out torch.Size([1, 10, 768])\n",
- "unembed.hook_in torch.Size([1, 10, 768])\n",
- "hook_unembed torch.Size([1, 10, 50257])\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"test_prompt = \"The quick brown fox jumped over the lazy dog\"\n",
"print(\"Num tokens:\", len(model.to_tokens(test_prompt)[0]))\n",
"\n",
+ "\n",
"def print_name_shape_hook_function(activation, hook):\n",
" print(hook.name, activation.shape)\n",
"\n",
- "not_in_late_block_filter = lambda name: name.startswith(\"blocks.0.\") or not name.startswith(\"blocks\")\n",
+ "\n",
+ "not_in_late_block_filter = lambda name: name.startswith(\"blocks.0.\") or not name.startswith(\n",
+ " \"blocks\"\n",
+ ")\n",
"\n",
"model.run_with_hooks(\n",
" test_prompt,\n",
@@ -1382,7 +1031,7 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -1392,59 +1041,9 @@
},
{
"cell_type": "code",
- "execution_count": 28,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Top 20 values\n",
- "7.03 ','\n",
- "6.98 ' the'\n",
- "6.68 ' and'\n",
- "6.49 '.'\n",
- "6.48 '\\n'\n",
- "6.47 ' a'\n",
- "6.41 ' in'\n",
- "6.25 ' to'\n",
- "6.16 ' of'\n",
- "6.04 '-'\n",
- "6.03 ' ('\n",
- "5.88 ' \"'\n",
- "5.80 ' for'\n",
- "5.72 ' that'\n",
- "5.64 ' on'\n",
- "5.59 ' is'\n",
- "5.52 ' as'\n",
- "5.49 ' at'\n",
- "5.45 ' with'\n",
- "5.44 ' or'\n",
- "...\n",
- "Bottom 20 values\n",
- "-3.82 ' サーティ'\n",
- "-3.83 '\\x18'\n",
- "-3.83 '\\x14'\n",
- "-3.83 ' RandomRedditor'\n",
- "-3.83 '龍�'\n",
- "-3.83 '�'\n",
- "-3.83 '\\x1b'\n",
- "-3.83 '�'\n",
- "-3.83 '\\x05'\n",
- "-3.83 '\\x00'\n",
- "-3.83 '\\x06'\n",
- "-3.83 '\\x07'\n",
- "-3.83 '\\x0c'\n",
- "-3.83 '\\x02'\n",
- "-3.83 'oreAndOnline'\n",
- "-3.84 '\\x11'\n",
- "-3.84 '�'\n",
- "-3.84 '\\x10'\n",
- "-3.84 '�'\n",
- "-3.84 '�'\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"top_k = 20\n",
"print(f\"Top {top_k} values\")\n",
@@ -1466,22 +1065,12 @@
},
{
"cell_type": "code",
- "execution_count": 29,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "John bias: 2.8995\n",
- "Mary bias: 1.6034\n",
- "Prob ratio bias: 3.6550x\n"
- ]
- }
- ],
- "source": [
- "john_bias = model.unembed.b_U[model.to_single_token(' John')]\n",
- "mary_bias = model.unembed.b_U[model.to_single_token(' Mary')]\n",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "john_bias = model.unembed.b_U[model.to_single_token(\" John\")]\n",
+ "mary_bias = model.unembed.b_U[model.to_single_token(\" Mary\")]\n",
"\n",
"print(f\"John bias: {john_bias.item():.4f}\")\n",
"print(f\"Mary bias: {mary_bias.item():.4f}\")\n",
@@ -1527,17 +1116,9 @@
},
{
"cell_type": "code",
- "execution_count": 30,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "['<|endoftext|>', 'The', ' first', ' thing', ' you', ' need', ' to', ' figure', ' out', ' is', ' *', 'how', '*', ' things', ' are', ' token', 'ized', '.', ' `', 'model', '.', 'to', '_', 'str', '_', 't', 'ok', 'ens', '`', ' splits', ' a', ' string', ' into', ' the', ' tokens', ' *', 'as', ' a', ' list', ' of', ' sub', 'strings', '*,', ' and', ' so', ' lets', ' you', ' explore', ' what', ' the', ' text', ' looks', ' like', '.', ' To', ' demonstrate', ' this', ',', ' let', \"'s\", ' use', ' it', ' on', ' this', ' paragraph', '.']\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"example_text = \"The first thing you need to figure out is *how* things are tokenized. `model.to_str_tokens` splits a string into the tokens *as a list of substrings*, and so lets you explore what the text looks like. To demonstrate this, let's use it on this paragraph.\"\n",
"example_text_str_tokens = model.to_str_tokens(example_text)\n",
@@ -1553,23 +1134,9 @@
},
{
"cell_type": "code",
- "execution_count": 31,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor([[50256, 464, 717, 1517, 345, 761, 284, 3785, 503, 318,\n",
- " 1635, 4919, 9, 1243, 389, 11241, 1143, 13, 4600, 19849,\n",
- " 13, 1462, 62, 2536, 62, 83, 482, 641, 63, 30778,\n",
- " 257, 4731, 656, 262, 16326, 1635, 292, 257, 1351, 286,\n",
- " 850, 37336, 25666, 290, 523, 8781, 345, 7301, 644, 262,\n",
- " 2420, 3073, 588, 13, 1675, 10176, 428, 11, 1309, 338,\n",
- " 779, 340, 319, 428, 7322, 13]])\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"example_text_tokens = model.to_tokens(example_text)\n",
"print(example_text_tokens)"
@@ -1586,18 +1153,9 @@
},
{
"cell_type": "code",
- "execution_count": 32,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor([[50256, 464, 3797, 3332, 319, 262, 2603, 13, 50256, 50256],\n",
- " [50256, 464, 3797, 3332, 319, 262, 2603, 1107, 1327, 13]])\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"example_multi_text = [\"The cat sat on the mat.\", \"The cat sat on the mat really hard.\"]\n",
"example_multi_text_tokens = model.to_tokens(example_multi_text)\n",
@@ -1622,18 +1180,9 @@
},
{
"cell_type": "code",
- "execution_count": 33,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Probability tensor shape [batch, position, d_vocab] == torch.Size([1, 8, 50257])\n",
- "| The| probability: 11.98%\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"cat_text = \"The cat sat on the mat.\"\n",
"cat_logits = model(cat_text)\n",
@@ -1655,21 +1204,12 @@
},
{
"cell_type": "code",
- "execution_count": 34,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Token 256 - the most common pair of ASCII characters: | t|\n",
- "De-Tokenizing the example tokens: <|endoftext|>The first thing you need to figure out is *how* things are tokenized. `model.to_str_tokens` splits a string into the tokens *as a list of substrings*, and so lets you explore what the text looks like. To demonstrate this, let's use it on this paragraph.\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"print(f\"Token 256 - the most common pair of ASCII characters: |{model.to_string(256)}|\")\n",
- "# Squeeze means to remove dimensions of length 1. \n",
+ "# Squeeze means to remove dimensions of length 1.\n",
"# Here, that removes the dummy batch dimension so it's a rank 1 tensor and returns a string\n",
"# Rank 2 tensors map to a list of strings\n",
"print(f\"De-Tokenizing the example tokens: {model.to_string(example_text_tokens.squeeze())}\")"
@@ -1686,18 +1226,9 @@
},
{
"cell_type": "code",
- "execution_count": 35,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "With BOS: 2\n",
- "Without BOS: 1\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"print(\"With BOS:\", model.get_token_position(\" cat\", \"The cat sat on the mat\"))\n",
"print(\"Without BOS:\", model.get_token_position(\" cat\", \"The cat sat on the mat\", prepend_bos=False))"
@@ -1712,27 +1243,22 @@
},
{
"cell_type": "code",
- "execution_count": 36,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "First occurrence 2\n",
- "Final occurrence 13\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "print(\"First occurrence\", model.get_token_position(\n",
- " \" cat\", \n",
- " \"The cat sat on the mat. The mat sat on the cat.\", \n",
- " mode=\"first\"))\n",
- "print(\"Final occurrence\", model.get_token_position(\n",
- " \" cat\", \n",
- " \"The cat sat on the mat. The mat sat on the cat.\", \n",
- " mode=\"last\"))"
+ "print(\n",
+ " \"First occurrence\",\n",
+ " model.get_token_position(\n",
+ " \" cat\", \"The cat sat on the mat. The mat sat on the cat.\", mode=\"first\"\n",
+ " ),\n",
+ ")\n",
+ "print(\n",
+ " \"Final occurrence\",\n",
+ " model.get_token_position(\n",
+ " \" cat\", \"The cat sat on the mat. The mat sat on the cat.\", mode=\"last\"\n",
+ " ),\n",
+ ")"
]
},
{
@@ -1744,18 +1270,9 @@
},
{
"cell_type": "code",
- "execution_count": 37,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "['<|endoftext|>', '23', '42', '+', '2017', '=', '214', '45']\n",
- "['<|endoftext|>', '1000', '+', '1', '000000', '=', '9999', '99']\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"print(model.to_str_tokens(\"2342+2017=21445\"))\n",
"print(model.to_str_tokens(\"1000+1000000=999999\"))"
@@ -1788,19 +1305,9 @@
},
{
"cell_type": "code",
- "execution_count": 38,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Logits shape by default (with BOS) torch.Size([1, 3, 50257])\n",
- "Logits shape with BOS torch.Size([1, 3, 50257])\n",
- "Logits shape without BOS - only 2 positions! torch.Size([1, 2, 50257])\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"print(\"Logits shape by default (with BOS)\", model(\"Hello World\").shape)\n",
"print(\"Logits shape with BOS\", model(\"Hello World\", prepend_bos=True).shape)\n",
@@ -1822,25 +1329,20 @@
},
{
"cell_type": "code",
- "execution_count": 39,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Logit difference with BOS: 6.754\n",
- "Logit difference without BOS: 2.782\n"
- ]
- }
- ],
- "source": [
- "ioi_logits_with_bos = model(\"Claire and Mary went to the shops, then Mary gave a bottle of milk to\", prepend_bos=True)\n",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ioi_logits_with_bos = model(\n",
+ " \"Claire and Mary went to the shops, then Mary gave a bottle of milk to\", prepend_bos=True\n",
+ ")\n",
"mary_logit_with_bos = ioi_logits_with_bos[0, -1, model.to_single_token(\" Mary\")].item()\n",
"claire_logit_with_bos = ioi_logits_with_bos[0, -1, model.to_single_token(\" Claire\")].item()\n",
"print(f\"Logit difference with BOS: {(claire_logit_with_bos - mary_logit_with_bos):.3f}\")\n",
"\n",
- "ioi_logits_without_bos = model(\"Claire and Mary went to the shops, then Mary gave a bottle of milk to\", prepend_bos=False)\n",
+ "ioi_logits_without_bos = model(\n",
+ " \"Claire and Mary went to the shops, then Mary gave a bottle of milk to\", prepend_bos=False\n",
+ ")\n",
"mary_logit_without_bos = ioi_logits_without_bos[0, -1, model.to_single_token(\" Mary\")].item()\n",
"claire_logit_without_bos = ioi_logits_without_bos[0, -1, model.to_single_token(\" Claire\")].item()\n",
"print(f\"Logit difference without BOS: {(claire_logit_without_bos - mary_logit_without_bos):.3f}\")"
@@ -1855,18 +1357,9 @@
},
{
"cell_type": "code",
- "execution_count": 40,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "| Claire| -> [' Claire']\n",
- "|Claire| -> ['Cl', 'aire']\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"print(f\"| Claire| -> {model.to_str_tokens(' Claire', prepend_bos=False)}\")\n",
"print(f\"|Claire| -> {model.to_str_tokens('Claire', prepend_bos=False)}\")"
@@ -1908,20 +1401,9 @@
},
{
"cell_type": "code",
- "execution_count": 41,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Norms:\n",
- "tensor(9.9105)\n",
- "tensor(9.9105)\n",
- "Right dimension: 5, Left dimension: 5, Hidden dimension: 2\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"if IN_GITHUB:\n",
" torch.manual_seed(50)\n",
@@ -1934,7 +1416,9 @@
"print(AB.norm())\n",
"print(AB_factor.norm())\n",
"\n",
- "print(f\"Right dimension: {AB_factor.rdim}, Left dimension: {AB_factor.ldim}, Hidden dimension: {AB_factor.mdim}\")"
+ "print(\n",
+ " f\"Right dimension: {AB_factor.rdim}, Left dimension: {AB_factor.ldim}, Hidden dimension: {AB_factor.mdim}\"\n",
+ ")"
]
},
{
@@ -1946,24 +1430,9 @@
},
{
"cell_type": "code",
- "execution_count": 42,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Eigenvalues:\n",
- "tensor([-6.2877e+00+0.j, -1.1103e-07+0.j, 2.3121e+00+0.j, -1.7900e-07+0.j,\n",
- " 9.0581e-08+0.j])\n",
- "tensor([-6.2877+0.j, 2.3121+0.j])\n",
- "\n",
- "Singular Values:\n",
- "tensor([8.3126e+00, 5.3963e+00, 3.2166e-07, 1.2748e-07, 1.9762e-08])\n",
- "tensor([8.3126, 5.3963])\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"print(\"Eigenvalues:\")\n",
@@ -1984,30 +1453,22 @@
},
{
"cell_type": "code",
- "execution_count": 43,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Unfactored: torch.Size([5, 300]) tensor(160.0830)\n",
- "Factored: torch.Size([5, 300]) tensor(160.0830)\n",
- "Right dimension: 300, Left dimension: 5, Hidden dimension: 2\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"if IN_GITHUB:\n",
" torch.manual_seed(50)\n",
- " \n",
+ "\n",
"C = torch.randn(5, 300)\n",
"\n",
"ABC = AB @ C\n",
"ABC_factor = AB_factor @ C\n",
"print(\"Unfactored:\", ABC.shape, ABC.norm().round(decimals=3))\n",
"print(\"Factored:\", ABC_factor.shape, ABC_factor.norm().round(decimals=3))\n",
- "print(f\"Right dimension: {ABC_factor.rdim}, Left dimension: {ABC_factor.ldim}, Hidden dimension: {ABC_factor.mdim}\")"
+ "print(\n",
+ " f\"Right dimension: {ABC_factor.rdim}, Left dimension: {ABC_factor.ldim}, Hidden dimension: {ABC_factor.mdim}\"\n",
+ ")"
]
},
{
@@ -2019,17 +1480,9 @@
},
{
"cell_type": "code",
- "execution_count": 44,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor(True)\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"AB_unfactored = AB_factor.AB\n",
"print(torch.isclose(AB_unfactored, AB).all())"
@@ -2062,17 +1515,9 @@
},
{
"cell_type": "code",
- "execution_count": 45,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "FactoredMatrix: Shape(torch.Size([12, 12, 768, 768])), Hidden Dim(64)\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"OV_circuit_all_heads = model.OV\n",
"print(OV_circuit_all_heads)"
@@ -2080,72 +1525,32 @@
},
{
"cell_type": "code",
- "execution_count": 46,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "torch.Size([12, 12, 64])\n",
- "torch.complex64\n"
- ]
- }
- ],
- "source": [
- "OV_circuit_all_heads_eigenvalues = OV_circuit_all_heads.eigenvalues \n",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "OV_circuit_all_heads_eigenvalues = OV_circuit_all_heads.eigenvalues\n",
"print(OV_circuit_all_heads_eigenvalues.shape)\n",
"print(OV_circuit_all_heads_eigenvalues.dtype)"
]
},
{
"cell_type": "code",
- "execution_count": 47,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "OV_copying_score = OV_circuit_all_heads_eigenvalues.sum(dim=-1).real / OV_circuit_all_heads_eigenvalues.abs().sum(dim=-1)\n",
- "imshow(utils.to_numpy(OV_copying_score), xaxis=\"Head\", yaxis=\"Layer\", title=\"OV Copying Score for each head in GPT-2 Small\", zmax=1.0, zmin=-1.0)"
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "OV_copying_score = OV_circuit_all_heads_eigenvalues.sum(\n",
+ " dim=-1\n",
+ ").real / OV_circuit_all_heads_eigenvalues.abs().sum(dim=-1)\n",
+ "imshow(\n",
+ " utils.to_numpy(OV_copying_score),\n",
+ " xaxis=\"Head\",\n",
+ " yaxis=\"Layer\",\n",
+ " title=\"OV Copying Score for each head in GPT-2 Small\",\n",
+ " zmax=1.0,\n",
+ " zmin=-1.0,\n",
+ ")"
]
},
{
@@ -2157,51 +1562,17 @@
},
{
"cell_type": "code",
- "execution_count": 48,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "scatter(x=OV_circuit_all_heads_eigenvalues[-1, -1, :].real, y=OV_circuit_all_heads_eigenvalues[-1, -1, :].imag, title=\"Eigenvalues of Head L11H11 of GPT-2 Small\", xaxis=\"Real\", yaxis=\"Imaginary\")"
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "scatter(\n",
+ " x=OV_circuit_all_heads_eigenvalues[-1, -1, :].real,\n",
+ " y=OV_circuit_all_heads_eigenvalues[-1, -1, :].imag,\n",
+ " title=\"Eigenvalues of Head L11H11 of GPT-2 Small\",\n",
+ " xaxis=\"Real\",\n",
+ " yaxis=\"Imaginary\",\n",
+ ")"
]
},
{
@@ -2213,17 +1584,9 @@
},
{
"cell_type": "code",
- "execution_count": 49,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "FactoredMatrix: Shape(torch.Size([12, 12, 50257, 50257])), Hidden Dim(64)\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"full_OV_circuit = model.embed.W_E @ OV_circuit_all_heads @ model.unembed.W_U\n",
"print(full_OV_circuit)"
@@ -2231,18 +1594,9 @@
},
{
"cell_type": "code",
- "execution_count": 50,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "torch.Size([12, 12, 64])\n",
- "torch.complex64\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"full_OV_circuit_eigenvalues = full_OV_circuit.eigenvalues\n",
"print(full_OV_circuit_eigenvalues.shape)\n",
@@ -2251,52 +1605,21 @@
},
{
"cell_type": "code",
- "execution_count": 51,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "full_OV_copying_score = full_OV_circuit_eigenvalues.sum(dim=-1).real / full_OV_circuit_eigenvalues.abs().sum(dim=-1)\n",
- "imshow(utils.to_numpy(full_OV_copying_score), xaxis=\"Head\", yaxis=\"Layer\", title=\"OV Copying Score for each head in GPT-2 Small\", zmax=1.0, zmin=-1.0)"
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "full_OV_copying_score = full_OV_circuit_eigenvalues.sum(\n",
+ " dim=-1\n",
+ ").real / full_OV_circuit_eigenvalues.abs().sum(dim=-1)\n",
+ "imshow(\n",
+ " utils.to_numpy(full_OV_copying_score),\n",
+ " xaxis=\"Head\",\n",
+ " yaxis=\"Layer\",\n",
+ " title=\"OV Copying Score for each head in GPT-2 Small\",\n",
+ " zmax=1.0,\n",
+ " zmin=-1.0,\n",
+ ")"
]
},
{
@@ -2308,70 +1631,28 @@
},
{
"cell_type": "code",
- "execution_count": 52,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "scatter(x=full_OV_copying_score.flatten(), y=OV_copying_score.flatten(), hover_name=[f\"L{layer}H{head}\" for layer in range(12) for head in range(12)], title=\"OV Copying Score for each head in GPT-2 Small\", xaxis=\"Full OV Copying Score\", yaxis=\"OV Copying Score\")"
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "scatter(\n",
+ " x=full_OV_copying_score.flatten(),\n",
+ " y=OV_copying_score.flatten(),\n",
+ " hover_name=[f\"L{layer}H{head}\" for layer in range(12) for head in range(12)],\n",
+ " title=\"OV Copying Score for each head in GPT-2 Small\",\n",
+ " xaxis=\"Full OV Copying Score\",\n",
+ " yaxis=\"OV Copying Score\",\n",
+ ")"
]
},
{
"cell_type": "code",
- "execution_count": 53,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Token 256 - the most common pair of ASCII characters: | t|\n",
- "De-Tokenizing the example tokens: <|endoftext|>The first thing you need to figure out is *how* things are tokenized. `model.to_str_tokens` splits a string into the tokens *as a list of substrings*, and so lets you explore what the text looks like. To demonstrate this, let's use it on this paragraph.\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"print(f\"Token 256 - the most common pair of ASCII characters: |{model.to_string(256)}|\")\n",
- "# Squeeze means to remove dimensions of length 1. \n",
+ "# Squeeze means to remove dimensions of length 1.\n",
"# Here, that removes the dummy batch dimension so it's a rank 1 tensor and returns a string\n",
"# Rank 2 tensors map to a list of strings\n",
"print(f\"De-Tokenizing the example tokens: {model.to_string(example_text_tokens.squeeze())}\")"
@@ -2394,23 +1675,17 @@
},
{
"cell_type": "code",
- "execution_count": 54,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'(CNN) President Barack Obama caught in embarrassing new scandal\\n\\nA top FBI official said the FBI will review all surveillance requests made by President Obama in recent months as part of its probe into possible ties between the Russian government and his campaign during the 2016 presidential campaign.\\n\\nThe official, who spoke to CNN'"
- ]
- },
- "execution_count": 54,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
- "model.generate(\"(CNN) President Barack Obama caught in embarrassing new scandal\\n\", max_new_tokens=50, temperature=0.7, prepend_bos=True)"
+ "model.generate(\n",
+ " \"(CNN) President Barack Obama caught in embarrassing new scandal\\n\",\n",
+ " max_new_tokens=50,\n",
+ " temperature=0.7,\n",
+ " prepend_bos=True,\n",
+ ")"
]
},
{
@@ -2459,11 +1734,10 @@
},
{
"cell_type": "code",
- "execution_count": 55,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "\n",
"from transformer_lens.hook_points import HookedRootModule, HookPoint\n",
"\n",
"\n",
@@ -2502,7 +1776,7 @@
" return x_out\n",
"\n",
"\n",
- "model = TwoLayerModel()\n"
+ "model = TwoLayerModel()"
]
},
{
@@ -2517,29 +1791,14 @@
},
{
"cell_type": "code",
- "execution_count": 56,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Model output: 780.0\n",
- "Value cached at hook hook_in 5.0\n",
- "Value cached at hook layer1.hook_square 25.0\n",
- "Value cached at hook hook_mid 28.0\n",
- "Value cached at hook layer2.hook_square 784.0\n",
- "Value cached at hook hook_out 780.0\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
- "\n",
"out, cache = model.run_with_cache(torch.tensor(5.0))\n",
"print(\"Model output:\", out.item())\n",
"for key in cache:\n",
- " print(f\"Value cached at hook {key}\", cache[key].item())\n",
- "\n"
+ " print(f\"Value cached at hook {key}\", cache[key].item())"
]
},
{
@@ -2552,20 +1811,10 @@
},
{
"cell_type": "code",
- "execution_count": 57,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "layer2.hook_square\n",
- "Output after intervening on layer2.hook_scaled -4.0\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "\n",
"def set_to_zero_hook(tensor, hook):\n",
" print(hook.name)\n",
" return torch.tensor(0.0)\n",
@@ -2619,213 +1868,32 @@
},
{
"cell_type": "code",
- "execution_count": 58,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- },
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"from transformer_lens.loading_from_pretrained import get_checkpoint_labels\n",
+ "\n",
"for model_name in [\"attn-only-2l\", \"solu-12l\", \"stanford-gpt2-small-a\"]:\n",
" checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(model_name)\n",
- " line(checkpoint_labels, xaxis=\"Checkpoint Index\", yaxis=f\"Checkpoint Value ({checkpoint_label_type})\", title=f\"Checkpoint Values for {model_name} (Log scale)\", log_y=True, markers=True)\n",
+ " line(\n",
+ " checkpoint_labels,\n",
+ " xaxis=\"Checkpoint Index\",\n",
+ " yaxis=f\"Checkpoint Value ({checkpoint_label_type})\",\n",
+ " title=f\"Checkpoint Values for {model_name} (Log scale)\",\n",
+ " log_y=True,\n",
+ " markers=True,\n",
+ " )\n",
"for model_name in [\"solu-1l-pile\", \"solu-6l-pile\"]:\n",
" checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(model_name)\n",
- " line(checkpoint_labels, xaxis=\"Checkpoint Index\", yaxis=f\"Checkpoint Value ({checkpoint_label_type})\", title=f\"Checkpoint Values for {model_name} (Linear scale)\", log_y=False, markers=True)"
+ " line(\n",
+ " checkpoint_labels,\n",
+ " xaxis=\"Checkpoint Index\",\n",
+ " yaxis=f\"Checkpoint Value ({checkpoint_label_type})\",\n",
+ " title=f\"Checkpoint Values for {model_name} (Linear scale)\",\n",
+ " log_y=False,\n",
+ " markers=True,\n",
+ " )"
]
},
{
@@ -2859,11 +1927,12 @@
},
{
"cell_type": "code",
- "execution_count": 59,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformer_lens import evals\n",
+ "\n",
"# We use the two layer model with SoLU activations, chosen fairly arbitrarily as being both small (so fast to download and keep in memory) and pretty good at the induction task.\n",
"model_name = \"solu-2l\"\n",
"# We can load a model from a checkpoint by specifying the checkpoint_index, -1 means the final checkpoint\n",
@@ -2882,20 +1951,24 @@
},
{
"cell_type": "code",
- "execution_count": 60,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if not IN_GITHUB:\n",
" for index in checkpoint_indices:\n",
" # Load the model from the relevant checkpoint by index\n",
- " model_for_this_checkpoint = HookedTransformer.from_pretrained(model_name, checkpoint_index=index, device=device)\n",
+ " model_for_this_checkpoint = HookedTransformer.from_pretrained(\n",
+ " model_name, checkpoint_index=index, device=device\n",
+ " )\n",
" checkpointed_models.append(model_for_this_checkpoint)\n",
"\n",
" tokens_seen_for_this_checkpoint = model_for_this_checkpoint.cfg.checkpoint_value\n",
" tokens_trained_on.append(tokens_seen_for_this_checkpoint)\n",
"\n",
- " induction_loss_for_this_checkpoint = evals.induction_loss(model_for_this_checkpoint, device=device).item()\n",
+ " induction_loss_for_this_checkpoint = evals.induction_loss(\n",
+ " model_for_this_checkpoint, device=device\n",
+ " ).item()\n",
" induction_losses.append(induction_loss_for_this_checkpoint)"
]
},
@@ -2910,51 +1983,19 @@
},
{
"cell_type": "code",
- "execution_count": 61,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "\n",
- "\n",
- " \n",
- "\n",
- ""
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "line(induction_losses, x=tokens_trained_on, xaxis=\"Tokens Trained On\", yaxis=\"Induction Loss\", title=\"Induction Loss over training: solu-2l\", markers=True, log_x=True)"
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "line(\n",
+ " induction_losses,\n",
+ " x=tokens_trained_on,\n",
+ " xaxis=\"Tokens Trained On\",\n",
+ " yaxis=\"Induction Loss\",\n",
+ " title=\"Induction Loss over training: solu-2l\",\n",
+ " markers=True,\n",
+ " log_x=True,\n",
+ ")"
]
}
],
diff --git a/tests/acceptance/test_hooked_encoder.py b/tests/acceptance/test_hooked_encoder.py
index d0f746d60..797ecbbf9 100644
--- a/tests/acceptance/test_hooked_encoder.py
+++ b/tests/acceptance/test_hooked_encoder.py
@@ -225,6 +225,6 @@ def test_input_list_of_strings_mlm(our_bert, huggingface_bert, tokenizer):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device")
-def test_cuda(mlm_tokens):
+def test_cuda(tokens):
model = HookedEncoder.from_pretrained(MODEL_NAME)
- model(mlm_tokens)
+ model(tokens)
diff --git a/tests/acceptance/test_multi_gpu.py b/tests/acceptance/test_multi_gpu.py
index 3af5eeeb2..ad407eb6e 100644
--- a/tests/acceptance/test_multi_gpu.py
+++ b/tests/acceptance/test_multi_gpu.py
@@ -111,7 +111,7 @@ def test_cache_device():
torch.device("cuda:1")
)
- logits, cache = model.run_with_cache("Hello there", device="cpu")
+ logits, cache = model.run_with_cache("Hello there", device=torch.device("cpu"))
assert norm_device(cache["blocks.0.mlp.hook_post"].device) == norm_device(torch.device("cpu"))
model.to("cuda")
diff --git a/tests/integration/model_bridge/test_optimizer_compatibility.py b/tests/integration/model_bridge/test_optimizer_compatibility.py
new file mode 100644
index 000000000..bffdb1428
--- /dev/null
+++ b/tests/integration/model_bridge/test_optimizer_compatibility.py
@@ -0,0 +1,334 @@
+"""Integration test for TransformerBridge optimizer compatibility.
+
+Tests that TransformerBridge works correctly with PyTorch optimizers,
+including parameter access, gradient flow, and parameter updates.
+"""
+
+from dataclasses import dataclass
+from typing import NamedTuple
+
+import torch
+
+from transformer_lens.model_bridge.bridge import TransformerBridge
+
+
+class StageThresholds(NamedTuple):
+ """Thresholds for a specific stage of validation."""
+
+ logits_max: float = 0.0
+ logits_mean: float = 0.0
+ loss_relative: float = 0.0
+ params_max: float = 0.0 # Only used for parameter update stages
+ params_mean: float = 0.0 # Only used for parameter update stages
+
+
+@dataclass
+class StepThresholds:
+ """Thresholds for all stages at a specific optimization step."""
+
+ step: int
+ initial_fwd: StageThresholds
+ post_update_fwd: StageThresholds
+ param_update: StageThresholds # Tracks parameter divergence after update
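+
+
+# Illustrative construction of the containers above; the numbers are
+# placeholders, not the thresholds used by the tests below:
+#
+#     StepThresholds(
+#         step=1,
+#         initial_fwd=StageThresholds(logits_max=1e-4, logits_mean=1e-5, loss_relative=1e-4),
+#         post_update_fwd=StageThresholds(logits_max=1e-3, logits_mean=1e-4, loss_relative=1e-3),
+#         param_update=StageThresholds(params_max=1e-4, params_mean=1e-5),
+#     )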
+
+
+def test_optimizer_workflow():
+ """Test complete optimizer workflow with TransformerBridge."""
+ # Load model
+ bridge = TransformerBridge.boot_transformers("distilgpt2")
+
+ # Verify parameters() returns leaf tensors
+ params = list(bridge.parameters())
+ assert len(params) > 0, "Should have parameters"
+ assert all(p.is_leaf for p in params), "All parameters should be leaf tensors"
+
+ # Verify optimizer creation succeeds
+ optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-4)
+ assert optimizer is not None, "Optimizer should be created successfully"
+
+ # Verify tl_parameters() returns TL-style dict
+ tl_params = bridge.tl_parameters()
+ assert len(tl_params) > 0, "Should have TL-style parameters"
+ assert any(
+ "blocks." in name and ".attn." in name for name in tl_params.keys()
+ ), "Should have TL-style parameter names like 'blocks.0.attn.W_Q'"
+
+ # Verify tl_named_parameters() iterator matches dict
+ tl_named_params = list(bridge.tl_named_parameters())
+ assert len(tl_named_params) == len(
+ tl_params
+ ), "Iterator should yield same number of parameters as dict"
+ iterator_dict = dict(tl_named_params)
+ for name, tensor in tl_params.items():
+ assert name in iterator_dict, f"Name {name} should be in iterator output"
+ assert torch.equal(iterator_dict[name], tensor), f"Tensor for {name} should match"
+
+ # Verify named_parameters() returns HF-style names
+ hf_names = [name for name, _ in bridge.named_parameters()]
+ assert len(hf_names) > 0, "Should have HF-style parameters"
+ assert any(
+ "_original_component" in name for name in hf_names
+ ), "Should have HuggingFace-style parameter names"
+
+ # Verify forward pass and backward work
+ device = next(bridge.parameters()).device
+ input_ids = torch.randint(0, bridge.cfg.d_vocab, (1, 10), device=device)
+ logits = bridge(input_ids)
+ expected_shape = (1, 10, bridge.cfg.d_vocab)
+ assert logits.shape == expected_shape, f"Expected shape {expected_shape}, got {logits.shape}"
+
+ loss = logits[0, -1].sum()
+ loss.backward()
+
+ # Verify gradients were computed
+ params_with_grad = [p for p in bridge.parameters() if p.grad is not None]
+ assert len(params_with_grad) > 0, "Should have parameters with gradients after backward()"
+
+ # Verify optimizer step updates parameters
+ param_before = list(bridge.parameters())[0].clone()
+ optimizer.step()
+ param_after = list(bridge.parameters())[0]
+ assert not torch.allclose(
+ param_before, param_after
+ ), "Parameters should be updated after optimizer.step()"
+
+
+def test_optimizer_compatibility_after_compatibility_mode():
+ """Test that optimizer still works after enabling compatibility mode."""
+ bridge = TransformerBridge.boot_transformers("distilgpt2")
+ bridge.enable_compatibility_mode(no_processing=True)
+
+ # Verify parameters are still leaf tensors after compatibility mode
+ params = list(bridge.parameters())
+ assert all(
+ p.is_leaf for p in params
+ ), "All parameters should still be leaf tensors after compatibility mode"
+
+ # Verify optimizer works after compatibility mode
+ optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-4)
+ device = next(bridge.parameters()).device
+ input_ids = torch.randint(0, bridge.cfg.d_vocab, (1, 10), device=device)
+
+ logits = bridge(input_ids)
+ loss = logits[0, -1].sum()
+ loss.backward()
+ optimizer.step()
+
+ # If we got here without errors, the test passed
+ assert True, "Optimizer should work after compatibility mode"
+
+
+def test_bridge_hooked_parity_multi_step_optimization():
+ """Test parity between Bridge and HookedTransformer across multiple optimization steps.
+
+ This test validates that both architectures maintain comparable results over
+ multiple optimization steps (1, 10), checking:
+ - Initial forward pass: logits and loss alignment before any updates
+ - Post-update forward pass: logits and loss remain close after each step
+ - Parameter updates: unembed weights remain close after each step
+
+ We focus on the unembed layer as it's a directly comparable component between
+ both architectures with matching shapes.
+ """
+ from transformer_lens import HookedTransformer
+
+    # Define thresholds for each step (rounded up to the next order of magnitude above the observed diffs + 30%)
+ step_thresholds = [
+ StepThresholds(
+ step=1,
+ initial_fwd=StageThresholds(logits_max=1e-3, logits_mean=1e-4, loss_relative=1e-6),
+ post_update_fwd=StageThresholds(logits_max=2.0, logits_mean=1e-3, loss_relative=1e-5),
+ param_update=StageThresholds(params_max=1e-2, params_mean=1e-6),
+ ),
+ StepThresholds(
+ step=10,
+ initial_fwd=StageThresholds(logits_max=20.0, logits_mean=0.1, loss_relative=1e-3),
+ post_update_fwd=StageThresholds(logits_max=20.0, logits_mean=0.1, loss_relative=1e-4),
+ param_update=StageThresholds(params_max=1e-1, params_mean=1e-5),
+ ),
+ ]
+
+ # Set seed for reproducibility
+ torch.manual_seed(42)
+
+ # Load both models with no weight processing for fair comparison
+ hooked = HookedTransformer.from_pretrained(
+ "distilgpt2",
+ device="cpu",
+ fold_ln=False,
+ center_writing_weights=False,
+ center_unembed=False,
+ fold_value_biases=False,
+ refactor_factored_attn_matrices=False,
+ )
+
+ bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu")
+ bridge.enable_compatibility_mode(no_processing=True)
+
+ # Create optimizers with same settings
+ hooked_optimizer = torch.optim.AdamW(hooked.parameters(), lr=1e-3)
+ bridge_optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-3)
+
+ # Create identical input with fixed seed
+ torch.manual_seed(42)
+ input_ids = torch.randint(0, bridge.cfg.d_vocab, (1, 10), device="cpu")
+
+ # Access unembed parameters for comparison
+ hooked_unembed_param = hooked.unembed.W_U
+ bridge_unembed_param = bridge.unembed._original_component.weight
+
+ # Verify shapes are compatible after transpose
+ assert hooked_unembed_param.T.shape == bridge_unembed_param.shape, (
+ f"Unembed parameter shapes should match after transpose: "
+ f"{hooked_unembed_param.T.shape} vs {bridge_unembed_param.shape}"
+ )
+
+ # Store initial parameters (should match since loaded from same checkpoint)
+ initial_hooked_unembed = hooked_unembed_param.data.clone()
+ initial_bridge_unembed = bridge_unembed_param.data.clone()
+ param_diff = (initial_hooked_unembed.T - initial_bridge_unembed).abs().max().item()
+ assert param_diff < 1e-4, (
+ f"Initial unembed parameters should match (loaded from same checkpoint). "
+ f"Max diff: {param_diff:.6e}"
+ )
+
+ # Track current step for threshold selection
+ current_step = 0
+
+ # Run optimization loop
+ for step_config in step_thresholds:
+ target_step = step_config.step
+
+ # Run optimization steps until we reach the target step
+ while current_step < target_step:
+ current_step += 1
+
+ # ===== INITIAL FORWARD PASS (before this step) =====
+ hooked_logits = hooked(input_ids, return_type="logits")
+ bridge_logits = bridge(input_ids, return_type="logits")
+
+ # Only validate initial forward on the target steps
+ if current_step == target_step:
+ logits_diff = (hooked_logits - bridge_logits).abs()
+ logits_max_diff = logits_diff.max().item()
+ logits_mean_diff = logits_diff.mean().item()
+
+ # Compare losses
+ hooked_loss = hooked_logits[0, -1].sum()
+ bridge_loss = bridge_logits[0, -1].sum()
+ loss_diff = abs(hooked_loss.item() - bridge_loss.item())
+ loss_relative_diff = loss_diff / (abs(hooked_loss.item()) + 1e-8)
+
+ assert logits_max_diff < step_config.initial_fwd.logits_max, (
+ f"Step {current_step}: Initial logits max diff {logits_max_diff:.6f} "
+ f"exceeds threshold {step_config.initial_fwd.logits_max:.6f}"
+ )
+ assert logits_mean_diff < step_config.initial_fwd.logits_mean, (
+ f"Step {current_step}: Initial logits mean diff {logits_mean_diff:.6f} "
+ f"exceeds threshold {step_config.initial_fwd.logits_mean:.6f}"
+ )
+
+ assert loss_relative_diff < step_config.initial_fwd.loss_relative, (
+ f"Step {current_step}: Initial loss relative diff {loss_relative_diff:.6f} "
+ f"exceeds threshold {step_config.initial_fwd.loss_relative:.6f}"
+ )
+
+ # Compute loss for backward
+ hooked_loss = hooked_logits[0, -1].sum()
+ bridge_loss = bridge_logits[0, -1].sum()
+
+ # ===== BACKWARD PASS =====
+ hooked_loss.backward()
+ bridge_loss.backward()
+
+ # Verify gradients exist and are reasonable (only on target steps)
+ if current_step == target_step:
+ assert (
+ hooked_unembed_param.grad is not None
+ ), "HookedTransformer unembed should have gradients"
+ assert bridge_unembed_param.grad is not None, "Bridge unembed should have gradients"
+
+ hooked_grad_mag = hooked_unembed_param.grad.abs().mean().item()
+ bridge_grad_mag = bridge_unembed_param.grad.abs().mean().item()
+
+ assert hooked_grad_mag > 1e-6 and hooked_grad_mag < 1e6, (
+ f"Step {current_step}: HookedTransformer gradients should be reasonable: "
+ f"{hooked_grad_mag:.6e}"
+ )
+ assert bridge_grad_mag > 1e-6 and bridge_grad_mag < 1e6, (
+ f"Step {current_step}: Bridge gradients should be reasonable: "
+ f"{bridge_grad_mag:.6e}"
+ )
+
+ # Store parameters before update (for validation on target steps)
+ if current_step == target_step:
+ hooked_unembed_before = hooked_unembed_param.data.clone()
+ bridge_unembed_before = bridge_unembed_param.data.clone()
+
+ # ===== OPTIMIZER STEP =====
+ hooked_optimizer.step()
+ bridge_optimizer.step()
+
+ # ===== VALIDATE PARAMETER UPDATES (on target steps) =====
+ if current_step == target_step:
+ hooked_unembed_after = hooked_unembed_param.data
+ bridge_unembed_after = bridge_unembed_param.data
+
+ # Verify parameters were updated
+ hooked_delta = hooked_unembed_after - hooked_unembed_before
+ bridge_delta = bridge_unembed_after - bridge_unembed_before
+ assert (
+ hooked_delta.abs().max() > 1e-8
+ ), f"Step {current_step}: HookedTransformer unembed should be updated"
+ assert (
+ bridge_delta.abs().max() > 1e-8
+ ), f"Step {current_step}: Bridge unembed should be updated"
+
+ # Verify parameters remain close
+ param_diff = (hooked_unembed_after.T - bridge_unembed_after).abs()
+ param_max_diff = param_diff.max().item()
+ param_mean_diff = param_diff.mean().item()
+
+ assert param_max_diff < step_config.param_update.params_max, (
+ f"Step {current_step}: Parameter max diff {param_max_diff:.6e} "
+ f"exceeds threshold {step_config.param_update.params_max:.6e}"
+ )
+ assert param_mean_diff < step_config.param_update.params_mean, (
+ f"Step {current_step}: Parameter mean diff {param_mean_diff:.6e} "
+ f"exceeds threshold {step_config.param_update.params_mean:.6e}"
+ )
+
+ # Zero gradients for next iteration
+ hooked_optimizer.zero_grad()
+ bridge_optimizer.zero_grad()
+
+ # ===== POST-UPDATE FORWARD PASS (on target steps) =====
+ if current_step == target_step:
+ with torch.no_grad():
+ hooked_logits_after = hooked(input_ids, return_type="logits")
+ bridge_logits_after = bridge(input_ids, return_type="logits")
+
+ logits_diff_after = (hooked_logits_after - bridge_logits_after).abs()
+ logits_max_diff_after = logits_diff_after.max().item()
+ logits_mean_diff_after = logits_diff_after.mean().item()
+
+ assert logits_max_diff_after < step_config.post_update_fwd.logits_max, (
+ f"Step {current_step}: Post-update logits max diff {logits_max_diff_after:.6f} "
+ f"exceeds threshold {step_config.post_update_fwd.logits_max:.6f}"
+ )
+ assert logits_mean_diff_after < step_config.post_update_fwd.logits_mean, (
+ f"Step {current_step}: Post-update logits mean diff {logits_mean_diff_after:.6f} "
+ f"exceeds threshold {step_config.post_update_fwd.logits_mean:.6f}"
+ )
+
+ # Compare losses after update
+ hooked_loss_after = hooked_logits_after[0, -1].sum()
+ bridge_loss_after = bridge_logits_after[0, -1].sum()
+ loss_diff_after = abs(hooked_loss_after.item() - bridge_loss_after.item())
+ loss_relative_diff_after = loss_diff_after / (abs(hooked_loss_after.item()) + 1e-8)
+
+ assert loss_relative_diff_after < step_config.post_update_fwd.loss_relative, (
+ f"Step {current_step}: Post-update loss relative diff "
+ f"{loss_relative_diff_after:.6f} exceeds threshold "
+ f"{step_config.post_update_fwd.loss_relative:.6f}"
+ )
diff --git a/tests/unit/components/test_attention.py b/tests/unit/components/test_attention.py
index b386660c6..c473cc491 100644
--- a/tests/unit/components/test_attention.py
+++ b/tests/unit/components/test_attention.py
@@ -80,6 +80,25 @@ def test_attention_load_in_4bit():
assert torch.all(attn.b_V == 0)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for half/bfloat16 tests")
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+def test_attention_forward_half_precisions(dtype):
+ # Construct a small attention block
+ cfg = HookedTransformerConfig(
+ d_model=64, d_head=16, n_heads=4, n_layers=1, n_ctx=8, dtype=dtype
+ )
+    attn = Attention(cfg).to("cuda")  # move the attention weights to CUDA to match the inputs below
+ # Random inputs in the matching dtype
+ batch = 1
+ seq = 4
+ x = torch.rand((batch, seq, cfg.d_model), dtype=dtype).to("cuda")
+ # Run forward through attention (q,k,v = x)
+ out = attn(x, x, x)
+    # Should not raise, and should return a CUDA tensor with a dtype matching (or compatible with) cfg.dtype
+ assert isinstance(out, torch.Tensor)
+ assert out.device.type == "cuda"
+
+
def test_attention_config_dict():
cfg = {
"n_layers": 12,
diff --git a/tests/unit/factored_matrix/test_multiply_by_scalar.py b/tests/unit/factored_matrix/test_multiply_by_scalar.py
index 85d0bfbe7..d5fbf29ba 100644
--- a/tests/unit/factored_matrix/test_multiply_by_scalar.py
+++ b/tests/unit/factored_matrix/test_multiply_by_scalar.py
@@ -23,6 +23,7 @@
), # Non-scalar Tensor. AssertionError expected.
(torch.rand(2), AssertionError), # Non-scalar Tensor. AssertionError expected.
],
+ ids=["tensor", "float", "int", "tensor_2d", "tensor_1d"],
)
@pytest.mark.parametrize("leading_dim", [False, True])
@pytest.mark.parametrize("multiply_from_left", [False, True])
diff --git a/tests/unit/model_bridge/test_optimizer_compatibility.py b/tests/unit/model_bridge/test_optimizer_compatibility.py
new file mode 100644
index 000000000..ea4056e4b
--- /dev/null
+++ b/tests/unit/model_bridge/test_optimizer_compatibility.py
@@ -0,0 +1,255 @@
+"""Tests for TransformerBridge optimizer compatibility.
+
+Ensures that TransformerBridge.parameters() returns only leaf tensors
+that are compatible with PyTorch optimizers.
+"""
+
+import pytest
+import torch
+from torch import nn
+
+from transformer_lens.model_bridge.bridge import TransformerBridge
+
+
+@pytest.fixture
+def small_bridge_model():
+ """Create a small TransformerBridge model for testing."""
+ model_name = "distilgpt2" # Use smaller model for faster tests
+ bridge = TransformerBridge.boot_transformers(model_name, device="cpu")
+
+ if bridge.tokenizer.pad_token is None:
+ bridge.tokenizer.pad_token = bridge.tokenizer.eos_token
+
+ return bridge
+
+
+class TestParametersAreLeafTensors:
+ """Test that parameters() returns only leaf tensors."""
+
+ def test_all_parameters_are_leaf_tensors(self, small_bridge_model):
+ """Verify all parameters returned by parameters() are leaf tensors."""
+ for i, param in enumerate(small_bridge_model.parameters()):
+ assert param.is_leaf, (
+ f"Parameter {i} is non-leaf (has grad_fn={param.grad_fn}). "
+ "Non-leaf tensors cannot be optimized."
+ )
+ assert isinstance(param, nn.Parameter), f"Parameter {i} is not an nn.Parameter"
+
+ def test_tl_parameters_provides_tl_style_names(self, small_bridge_model):
+ """Verify tl_parameters() provides TransformerLens-style parameter dictionary.
+
+ tl_parameters() returns processed weights for analysis (via SVDInterpreter, etc.).
+ These may include non-leaf tensors created by einops.rearrange(), which is
+ expected and correct for TransformerLens compatibility.
+
+ For optimization, use parameters() which returns only leaf tensors.
+ """
+ # Get TL-style parameters
+ tl_params = small_bridge_model.tl_parameters()
+
+ # Check that we have TL-style names (blocks.X.attn.W_Y format)
+ assert any(
+ "blocks." in name and ".attn." in name for name in tl_params.keys()
+ ), "Expected TransformerLens-style parameter names like 'blocks.0.attn.W_Q'"
+
+ # Check that some common TL parameter names exist
+ assert any(
+ name.endswith(".W_E") for name in tl_params.keys()
+ ), "Expected embedding parameter 'W_E'"
+
+ def test_tl_named_parameters_provides_iterator(self, small_bridge_model):
+ """Verify tl_named_parameters() provides iterator with TL-style names.
+
+ This method provides the same content as tl_parameters() but as an iterator,
+ maintaining consistency with PyTorch's named_parameters() API pattern.
+ """
+ # Get TL-style parameters as iterator
+ tl_named_params = list(small_bridge_model.tl_named_parameters())
+ tl_params_dict = small_bridge_model.tl_parameters()
+
+ # Verify iterator returns same content as dictionary
+ assert len(tl_named_params) == len(
+ tl_params_dict
+ ), "Iterator should yield same number of parameters as dict"
+
+ # Verify names and tensors match
+ iterator_dict = dict(tl_named_params)
+ for name, tensor in tl_params_dict.items():
+ assert name in iterator_dict, f"Name {name} should be in iterator output"
+ assert torch.equal(iterator_dict[name], tensor), f"Tensor for {name} should match"
+
+ # Check that we have TL-style names (blocks.X.attn.W_Y format)
+ param_names = [name for name, _ in tl_named_params]
+ assert any(
+ "blocks." in name and ".attn." in name for name in param_names
+ ), "Expected TransformerLens-style parameter names like 'blocks.0.attn.W_Q'"
+
+ def test_no_processed_weights_in_parameters(self, small_bridge_model):
+ """Verify processed weight attributes are not included in parameters().
+
+ Note: This test verifies that parameters() (e.g. for optimizers) doesn't include
+ internal processed weight attributes. These weights should only appear in
+ tl_parameters().
+ """
+ # Enable compatibility mode to create processed weights
+ small_bridge_model.enable_compatibility_mode(no_processing=True)
+
+ # Get all parameter names from PyTorch-style named_parameters()
+ hf_param_names = {name for name, _ in small_bridge_model.named_parameters()}
+
+ # Check that processed weight attribute names are NOT in parameters
+ # (They should exist as attributes but not be trainable parameters)
+ for block_idx in range(small_bridge_model.cfg.n_layers):
+ # These are the processed weight attributes created by _set_processed_weight_attributes
+ processed_weight_attrs = [
+ f"blocks.{block_idx}.attn._processed_W_Q",
+ f"blocks.{block_idx}.attn._processed_W_K",
+ f"blocks.{block_idx}.attn._processed_W_V",
+ f"blocks.{block_idx}.attn._processed_W_O",
+ ]
+
+ for attr_name in processed_weight_attrs:
+ # The attribute might exist on the object but should NOT be in parameters()
+ assert attr_name not in hf_param_names, (
+ f"Processed weight attribute '{attr_name}' should not be in parameters(). "
+ "Processed weights are views for analysis, not trainable parameters."
+ )
+
+
+class TestOptimizerCompatibility:
+ """Test that TransformerBridge works with standard PyTorch optimizers."""
+
+ def test_adamw_accepts_parameters(self, small_bridge_model):
+ """Test that AdamW optimizer accepts TransformerBridge parameters."""
+ # This should not raise "can't optimize a non-leaf Tensor"
+ try:
+ optimizer = torch.optim.AdamW(small_bridge_model.parameters(), lr=1e-4)
+ assert optimizer is not None
+ except ValueError as e:
+ if "can't optimize a non-leaf Tensor" in str(e):
+ pytest.fail(
+ "AdamW rejected TransformerBridge parameters. "
+ "This indicates non-leaf tensors are being returned by parameters()."
+ )
+ raise
+
+ def test_gradient_flow_after_backward(self, small_bridge_model):
+ """Test that gradients flow correctly after backward pass."""
+ small_bridge_model.train()
+ input_ids = torch.randint(0, small_bridge_model.cfg.d_vocab, (1, 10))
+ logits = small_bridge_model(input_ids, return_type="logits")
+ loss = logits.sum()
+ loss.backward()
+
+ # Verify that parameters have gradients
+ params_with_grad = 0
+ total_params = 0
+
+ for param in small_bridge_model.parameters():
+ total_params += 1
+ if param.grad is not None:
+ params_with_grad += 1
+ # Verify gradient is on a leaf tensor
+ assert param.is_leaf, "Gradient was computed for a non-leaf tensor"
+
+ # At least some parameters should have gradients
+ assert params_with_grad > 0, "No parameters received gradients after backward pass"
+
+ def test_optimizer_step_updates_parameters(self, small_bridge_model):
+ """Test that optimizer.step() actually updates model parameters."""
+ small_bridge_model.train()
+ optimizer = torch.optim.SGD(small_bridge_model.parameters(), lr=0.1)
+
+ # Get initial parameter values (first few params for efficiency)
+ initial_params = {}
+ for i, (name, param) in enumerate(small_bridge_model.named_parameters()):
+ if i >= 5: # Just check first 5 parameters
+ break
+ initial_params[name] = param.data.clone()
+
+ # Create dummy input and compute loss
+ input_ids = torch.randint(0, small_bridge_model.cfg.d_vocab, (1, 10))
+ logits = small_bridge_model(input_ids, return_type="logits")
+ loss = logits.sum()
+
+ # Backward and step
+ loss.backward()
+ optimizer.step()
+
+ # Verify parameters were updated
+ params_updated = 0
+ for name, initial_value in initial_params.items():
+ current_value = dict(small_bridge_model.named_parameters())[name].data
+
+ # Check if parameter changed
+ if not torch.allclose(initial_value, current_value, atol=1e-8):
+ params_updated += 1
+
+ assert params_updated > 0, (
+ "No parameters were updated after optimizer.step(). "
+ "This suggests the optimizer is not correctly connected to the model parameters."
+ )
+
+
+class TestParametersAfterCompatibilityMode:
+ """Test parameters() behavior after enabling compatibility mode."""
+
+ def test_parameters_still_leaf_after_compatibility_mode(self, small_bridge_model):
+ """Verify parameters() returns leaf tensors even after enabling compatibility mode."""
+ # Enable compatibility mode (which creates processed weights)
+ small_bridge_model.enable_compatibility_mode(no_processing=True)
+
+ # All parameters from parameters() should still be leaf tensors
+ # (named_parameters() may include non-leaf processed weights for TL compatibility)
+ for i, param in enumerate(small_bridge_model.parameters()):
+ assert param.is_leaf, (
+ f"Parameter {i} from parameters() is non-leaf after compatibility mode. "
+ "Compatibility mode should not affect trainable parameters from parameters()."
+ )
+
+ def test_optimizer_works_after_compatibility_mode(self, small_bridge_model):
+ """Test that optimizers still work after enabling compatibility mode."""
+ # Enable compatibility mode
+ small_bridge_model.enable_compatibility_mode(no_processing=True)
+
+ # Should still be able to create optimizer
+ try:
+ optimizer = torch.optim.AdamW(small_bridge_model.parameters(), lr=1e-4)
+ assert optimizer is not None
+ except ValueError as e:
+ if "can't optimize a non-leaf Tensor" in str(e):
+ pytest.fail(
+ "AdamW rejected parameters after compatibility mode. "
+ "This indicates non-leaf tensors are being returned."
+ )
+ raise
+
+
+class TestParametersMatchOriginalModel:
+ """Test that parameters() returns the same parameters as the original HF model."""
+
+ def test_parameter_count_matches(self, small_bridge_model):
+ """Verify parameter count matches original model."""
+ bridge_param_count = sum(1 for _ in small_bridge_model.parameters())
+ original_param_count = sum(1 for _ in small_bridge_model.original_model.parameters())
+
+ assert bridge_param_count == original_param_count, (
+ f"Parameter count mismatch: Bridge has {bridge_param_count}, "
+ f"original model has {original_param_count}"
+ )
+
+ def test_parameters_are_same_objects(self, small_bridge_model):
+ """Verify that parameters() returns the actual original model parameters."""
+ bridge_params = list(small_bridge_model.parameters())
+ original_params = list(small_bridge_model.original_model.parameters())
+
+ # Should have same number of parameters
+ assert len(bridge_params) == len(original_params)
+
+ # Parameters should be the same objects (same id)
+ # This ensures gradients flow to the original model
+ for bridge_param, original_param in zip(bridge_params, original_params):
+ assert bridge_param is original_param, (
+ "Bridge parameters should be the exact same objects as original model parameters. "
+ "This ensures gradient flow and memory efficiency."
+ )
diff --git a/transformer_lens/SVDInterpreter.py b/transformer_lens/SVDInterpreter.py
index e0814cc79..bf69beec4 100644
--- a/transformer_lens/SVDInterpreter.py
+++ b/transformer_lens/SVDInterpreter.py
@@ -20,7 +20,12 @@ class SVDInterpreter:
def __init__(self, model: Any):
self.model = model
self.cfg = model.cfg
- self.params = {name: param for name, param in model.named_parameters()}
+ # Use tl_parameters() for TransformerBridge (returns TL-style dict)
+ # Fall back to named_parameters() for HookedTransformer
+ if hasattr(model, "tl_parameters"):
+ self.params = model.tl_parameters()
+ else:
+ self.params = {name: param for name, param in model.named_parameters()}
@typechecked
def get_singular_vectors(
diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py
index 5f026f493..0d144a741 100644
--- a/transformer_lens/components/abstract_attention.py
+++ b/transformer_lens/components/abstract_attention.py
@@ -8,6 +8,7 @@
import torch.nn.functional as F
from better_abc import abstract_attribute
from jaxtyping import Float, Int
+from torch import Tensor
from transformers.utils.import_utils import is_bitsandbytes_available
from transformer_lens.cache.key_value_cache_entry import (
@@ -280,8 +281,7 @@ def forward(
raise TypeError(f"Expected 'pattern' to be a Tensor, got {type(pattern)}")
pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos]
- pattern = pattern.to(self.cfg.dtype)
- pattern = pattern.to(v.device)
+ pattern = pattern.to(device=v.device, dtype=v.dtype)
z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head]
if not self.cfg.use_attn_result:
if self.cfg.load_in_4bit:
@@ -301,15 +301,21 @@ def forward(
self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
)
- if self.b_O.device != w.device:
- w = w.to(self.b_O.device)
- if self.b_O.device != z.device:
- z = z.to(self.b_O.device)
+ # Move output projection weights and bias to the same device as z
+ # so that the final linear operation occurs on the device of the inputs
+ if w.device != z.device:
+ w = w.to(z.device)
+ b_O: Tensor = self.b_O
+ if b_O.device != z.device:
+ b_O = b_O.to(z.device)
+ # Ensure z has the same dtype as weights used in the output projection
+ if z.dtype != w.dtype:
+ z = z.to(w.dtype)
out = F.linear(
z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
w,
- self.b_O,
+ b_O,
)
else:
# Explicitly calculate the attention result so it can be accessed by a hook
@@ -329,6 +335,11 @@ def forward(
self.W_O,
"head_index d_head d_model -> 1 1 head_index d_head d_model",
)
+ if w.device != z.device:
+ w = w.to(z.device)
+ # Ensure z has the same dtype as w before multiplication
+ if z.dtype != w.dtype:
+ z = z.to(w.dtype)
z = einops.rearrange(
z, "batch pos head_index d_head -> batch pos head_index d_head 1"
)
diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py
index 141af5e2e..0fd77ce8d 100644
--- a/transformer_lens/model_bridge/bridge.py
+++ b/transformer_lens/model_bridge/bridge.py
@@ -39,6 +39,7 @@
)
from transformer_lens.model_bridge.get_params_util import get_bridge_params
from transformer_lens.utilities.aliases import resolve_alias
+from transformer_lens.utilities.devices import move_to_and_update_config
if TYPE_CHECKING:
from transformer_lens.ActivationCache import ActivationCache
@@ -1078,17 +1079,73 @@ def QK(self):
def OV(self):
return FactoredMatrix(self.W_V, self.W_O)
+ def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
+ """Returns parameters following standard PyTorch semantics.
+
+ This method delegates to the underlying HuggingFace model's parameters().
+        For a TransformerLens-style parameter dictionary, use tl_parameters() instead.
+
+ Args:
+ recurse: If True, yields parameters of this module and all submodules
+
+ Returns:
+ Iterator of nn.Parameter objects
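+
+        Example (illustrative sketch; mirrors the optimizer setup exercised by this PR's tests):
+            >>> bridge = TransformerBridge.boot_transformers("distilgpt2")
+            >>> optimizer = torch.optim.AdamW(bridge.parameters(), lr=1e-4)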
+ """
+ return self.original_model.parameters(recurse=recurse)
+
def named_parameters(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
- ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
- """Return named parameters in the same format as TransformerLens.
+ ) -> Iterator[tuple[str, nn.Parameter]]:
+ """Returns named parameters following standard PyTorch semantics.
+
+ This method delegates to the underlying HuggingFace model's named_parameters().
+        For a TransformerLens-style iterator, use tl_named_parameters() instead.
+
+ Args:
+ prefix: Prefix to prepend to all parameter names
+ recurse: If True, yields parameters of this module and all submodules
+ remove_duplicate: If True, removes duplicate parameters
+
+ Returns:
+ Iterator of (name, parameter) tuples
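+
+        Example (illustrative; names are HuggingFace-style and depend on the wrapped
+        architecture, e.g. they include '_original_component', as asserted in this PR's tests):
+            >>> bridge = TransformerBridge.boot_transformers("distilgpt2")
+            >>> hf_names = [name for name, _ in bridge.named_parameters()]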
+ """
+ return self.original_model.named_parameters(prefix, recurse, remove_duplicate)
+
+ def tl_parameters(self) -> dict[str, torch.Tensor]:
+ """Returns TransformerLens-style parameter dictionary.
- This ensures compatibility with tools like SVDInterpreter that expect
- parameter names like 'blocks.0.attn.W_Q' instead of the raw model names.
+ Parameter names follow TransformerLens conventions (e.g., 'blocks.0.attn.W_Q') and may
+ include processed weights (non-leaf tensors). This format is expected by SVDInterpreter
+        and other analysis tools.
+
+ Returns:
+ Dictionary mapping TransformerLens parameter names to tensors
+
+ Example:
+ >>> bridge = TransformerBridge.boot_transformers("gpt2")
+ >>> tl_params = bridge.tl_parameters()
+ >>> W_Q = tl_params["blocks.0.attn.W_Q"] # Shape: [n_heads, d_model, d_head]
+ """
+ return self.get_params()
+
+ def tl_named_parameters(self) -> Iterator[tuple[str, torch.Tensor]]:
+ """Returns iterator of TransformerLens-style named parameters.
+
+ This provides the same parameters as tl_parameters() but as an iterator
+ for consistency with PyTorch's named_parameters() API pattern.
+
+ Returns:
+ Iterator of (name, tensor) tuples with TransformerLens naming conventions
+
+ Example:
+ >>> bridge = TransformerBridge.boot_transformers("gpt2")
+ >>> for name, param in bridge.tl_named_parameters():
+ ... if "attn.W_Q" in name:
+ ... print(f"{name}: {param.shape}") # doctest: +ELLIPSIS
+ blocks.0.attn.W_Q: torch.Size([12, 768, 64])
+ ...
"""
- params_dict = self.get_params()
- for name, param in params_dict.items():
- yield (name, param)
+ return iter(self.get_params().items())
def forward(
self,
@@ -1754,7 +1811,7 @@ def generate(
return output_tokens
def to(self, *args, **kwargs) -> "TransformerBridge":
- """Move model to device or change dtype.
+ """Move model to device and/or change dtype.
Args:
args: Positional arguments for nn.Module.to
@@ -1763,6 +1820,37 @@ def to(self, *args, **kwargs) -> "TransformerBridge":
Returns:
Self for chaining
"""
+ # Extract print_details if provided
+ print_details = kwargs.pop("print_details", True)
+
+ # Handle both device and dtype changes
+ # torch.nn.Module.to() supports: to(device), to(dtype), to(device, dtype),
+ # to(device=...), to(dtype=...), to(device=..., dtype=...)
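+        # e.g. (illustrative): bridge.to("cuda"), bridge.to(torch.float16),
+        #      bridge.to("cuda", torch.float16), bridge.to(device="cpu", dtype=torch.float32)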
+ target_device, target_dtype = None, None
+
+ if len(args) >= 1:
+ first_arg = args[0]
+ if isinstance(first_arg, (torch.device, str)):
+ target_device = first_arg
+ elif isinstance(first_arg, torch.dtype):
+ target_dtype = first_arg
+ if len(args) >= 2:
+ second_arg = args[1]
+ if isinstance(second_arg, torch.dtype):
+ target_dtype = second_arg
+
+        # Keyword arguments override positional arguments
+ if "device" in kwargs:
+ target_device = kwargs["device"]
+ if "dtype" in kwargs:
+ target_dtype = kwargs["dtype"]
+
+ if target_device is not None:
+ move_to_and_update_config(self, target_device, print_details)
+ if target_dtype is not None:
+ move_to_and_update_config(self, target_dtype, print_details)
+
+ # Move the original model with all original args/kwargs (with print_details removed)
self.original_model = self.original_model.to(*args, **kwargs)
return self